Mirror of https://github.com/googleapis/genai-toolbox.git, synced 2026-01-11 16:38:15 -05:00
feat(tools/bigquery-analyze-contribution): Add analyze contribution tool (#1223)

This tool creates a contribution analysis model and uses ML.GET_INSIGHTS to get the results.

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
Co-authored-by: Averi Kitsch <akitsch@google.com>
@@ -43,6 +43,7 @@ import (
 	// Import tool packages for side effect of registration
 	_ "github.com/googleapis/genai-toolbox/internal/tools/alloydbainl"
+	_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryanalyzecontribution"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryconversationalanalytics"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryexecutesql"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryforecast"
@@ -1359,7 +1359,7 @@ func TestPrebuiltTools(t *testing.T) {
 		wantToolset: server.ToolsetConfigs{
 			"bigquery-database-tools": tools.ToolsetConfig{
 				Name:      "bigquery-database-tools",
-				ToolNames: []string{"ask_data_insights", "execute_sql", "forecast", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids"},
+				ToolNames: []string{"analyze_contribution", "ask_data_insights", "execute_sql", "forecast", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids"},
 			},
 		},
 	},
@@ -0,0 +1,52 @@
---
title: "bigquery-analyze-contribution"
type: docs
weight: 1
description: >
  A "bigquery-analyze-contribution" tool performs contribution analysis in BigQuery.
aliases:
- /resources/tools/bigquery-analyze-contribution
---

## About

A `bigquery-analyze-contribution` tool performs contribution analysis in BigQuery by creating a temporary `CONTRIBUTION_ANALYSIS` model and then querying it with `ML.GET_INSIGHTS` to find top contributors for a given metric.
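
Under the hood, the tool issues a pair of statements along these lines (a simplified sketch of the SQL the tool builds internally; the model name is generated per call, and the option values shown mirror example parameters):

```sql
-- Create a temporary contribution analysis model over the input data.
CREATE TEMP MODEL contribution_analysis_model
OPTIONS(
  MODEL_TYPE = 'CONTRIBUTION_ANALYSIS',
  CONTRIBUTION_METRIC = 'SUM(metric)',
  IS_TEST_COL = 'is_test',
  DIMENSION_ID_COLS = ['dim1', 'dim2']
) AS SELECT * FROM `my-project.my_dataset.my_table`;

-- Read the ranked insights back in the same session.
SELECT * FROM ML.GET_INSIGHTS(MODEL contribution_analysis_model);
```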

It's compatible with the following sources:

- [bigquery](../../sources/bigquery.md)

`bigquery-analyze-contribution` takes the following parameters:

- **input_data** (string, required): The data that contains the test and control data to analyze. This can be a fully qualified BigQuery table ID (e.g., `my-project.my_dataset.my_table`) or a SQL query that returns the data.
- **contribution_metric** (string, required): The expression that defines the metric to analyze. This can be `SUM(metric_column_name)`, `SUM(numerator_metric_column_name)/SUM(denominator_metric_column_name)`, or `SUM(metric_sum_column_name)/COUNT(DISTINCT categorical_column_name)`, depending on the type of metric to analyze.
- **is_test_col** (string, required): The name of the column that identifies whether a row is in the test or control group. The column must contain boolean values.
- **dimension_id_cols** (array of strings, optional): An array of column names that uniquely identify each dimension.
- **top_k_insights_by_apriori_support** (integer, optional): The number of top insights to return, ranked by apriori support. Defaults to `30`.
- **pruning_method** (string, optional): The method to use for pruning redundant insights. Can be `'NO_PRUNING'` or `'PRUNE_REDUNDANT_INSIGHTS'`. Defaults to `'PRUNE_REDUNDANT_INSIGHTS'`.
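
For reference, an invocation payload for this tool might look like the following (a sketch modeled on the integration tests in this change; the table ID is illustrative):

```json
{
  "input_data": "my-project.my_dataset.my_table",
  "contribution_metric": "SUM(metric)",
  "is_test_col": "is_test",
  "dimension_id_cols": ["dim1", "dim2"]
}
```
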
## Example

```yaml
tools:
  contribution_analyzer:
    kind: bigquery-analyze-contribution
    source: my-bigquery-source
    description: Use this tool to run contribution analysis on a dataset in BigQuery.
```

## Sample Prompt

You can prepare a sample table by following the tutorial at
https://cloud.google.com/bigquery/docs/get-contribution-analysis-insights,
and then use the following sample prompts to call this tool:

- What drives the changes in sales in the table `bqml_tutorial.iowa_liquor_sales_sum_data`? Use the project id `myproject`.
- Analyze the contribution for the `total_sales` metric in the table `bqml_tutorial.iowa_liquor_sales_sum_data`. The test group is identified by the `is_test` column. The dimensions are `store_name`, `city`, `vendor_name`, `category_name` and `item_description`.

## Reference

| **field**   | **type** | **required** | **description**                                    |
|-------------|:--------:|:------------:|----------------------------------------------------|
| kind        | string   | true         | Must be "bigquery-analyze-contribution".           |
| source      | string   | true         | Name of the source the tool should execute on.     |
| description | string   | true         | Description of the tool that is passed to the LLM. |
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 sources:
   bigquery-source:
     kind: "bigquery"
@@ -18,6 +19,11 @@ sources:
     location: ${BIGQUERY_LOCATION:}

 tools:
+  analyze_contribution:
+    kind: bigquery-analyze-contribution
+    source: bigquery-source
+    description: Use this tool to analyze the contribution of changes to key metrics in multi-dimensional data.
+
   ask_data_insights:
     kind: bigquery-conversational-analytics
     source: bigquery-source
@@ -58,6 +64,7 @@ tools:

 toolsets:
   bigquery-database-tools:
+    - analyze_contribution
     - ask_data_insights
     - execute_sql
     - forecast
@@ -0,0 +1,307 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bigqueryanalyzecontribution

import (
	"context"
	"fmt"
	"strings"

	bigqueryapi "cloud.google.com/go/bigquery"
	yaml "github.com/goccy/go-yaml"
	"github.com/google/uuid"
	"github.com/googleapis/genai-toolbox/internal/sources"
	bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
	"github.com/googleapis/genai-toolbox/internal/tools"
	bigqueryrestapi "google.golang.org/api/bigquery/v2"
	"google.golang.org/api/iterator"
)

const kind string = "bigquery-analyze-contribution"

func init() {
	if !tools.Register(kind, newConfig) {
		panic(fmt.Sprintf("tool kind %q already registered", kind))
	}
}

func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
	actual := Config{Name: name}
	if err := decoder.DecodeContext(ctx, &actual); err != nil {
		return nil, err
	}
	return actual, nil
}

type compatibleSource interface {
	BigQueryClient() *bigqueryapi.Client
	BigQueryRestService() *bigqueryrestapi.Service
	BigQueryClientCreator() bigqueryds.BigqueryClientCreator
	UseClientAuthorization() bool
}

// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}

var compatibleSources = [...]string{bigqueryds.SourceKind}

type Config struct {
	Name         string   `yaml:"name" validate:"required"`
	Kind         string   `yaml:"kind" validate:"required"`
	Source       string   `yaml:"source" validate:"required"`
	Description  string   `yaml:"description" validate:"required"`
	AuthRequired []string `yaml:"authRequired"`
}

// validate interface
var _ tools.ToolConfig = Config{}

func (cfg Config) ToolConfigKind() string {
	return kind
}

func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
	// verify source exists
	rawS, ok := srcs[cfg.Source]
	if !ok {
		return nil, fmt.Errorf("no source named %q configured", cfg.Source)
	}

	// verify the source is compatible
	s, ok := rawS.(compatibleSource)
	if !ok {
		return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
	}

	inputDataParameter := tools.NewStringParameter("input_data",
		"The data that contains the test and control data to analyze. Can be a fully qualified BigQuery table ID or a SQL query.")
	contributionMetricParameter := tools.NewStringParameter("contribution_metric",
		`The expression to use to calculate the metric you are analyzing.
To calculate a summable metric, the expression must be in the form SUM(metric_column_name),
where metric_column_name is a numeric data type.

To calculate a summable ratio metric, the expression must be in the form
SUM(numerator_metric_column_name)/SUM(denominator_metric_column_name),
where numerator_metric_column_name and denominator_metric_column_name are numeric data types.

To calculate a summable by category metric, the expression must be in the form
SUM(metric_sum_column_name)/COUNT(DISTINCT categorical_column_name). The summed column must be a numeric data type.
The categorical column must have type BOOL, DATE, DATETIME, TIME, TIMESTAMP, STRING, or INT64.`)
	isTestColParameter := tools.NewStringParameter("is_test_col",
		"The name of the column that identifies whether a row is in the test or control group.")
	dimensionIDColsParameter := tools.NewArrayParameterWithRequired("dimension_id_cols",
		"An array of column names that uniquely identify each dimension.", false, tools.NewStringParameter("dimension_id_col", "A dimension column name."))
	topKInsightsParameter := tools.NewIntParameterWithDefault("top_k_insights_by_apriori_support", 30,
		"The number of top insights to return, ranked by apriori support.")
	pruningMethodParameter := tools.NewStringParameterWithDefault("pruning_method", "PRUNE_REDUNDANT_INSIGHTS",
		"The method to use for pruning redundant insights. Can be 'NO_PRUNING' or 'PRUNE_REDUNDANT_INSIGHTS'.")

	parameters := tools.Parameters{
		inputDataParameter,
		contributionMetricParameter,
		isTestColParameter,
		dimensionIDColsParameter,
		topKInsightsParameter,
		pruningMethodParameter,
	}

	mcpManifest := tools.McpManifest{
		Name:        cfg.Name,
		Description: cfg.Description,
		InputSchema: parameters.McpManifest(),
	}

	// finish tool setup
	t := Tool{
		Name:           cfg.Name,
		Kind:           kind,
		Parameters:     parameters,
		AuthRequired:   cfg.AuthRequired,
		UseClientOAuth: s.UseClientAuthorization(),
		ClientCreator:  s.BigQueryClientCreator(),
		Client:         s.BigQueryClient(),
		RestService:    s.BigQueryRestService(),
		manifest:       tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
		mcpManifest:    mcpManifest,
	}
	return t, nil
}

// validate interface
var _ tools.Tool = Tool{}

type Tool struct {
	Name           string           `yaml:"name"`
	Kind           string           `yaml:"kind"`
	AuthRequired   []string         `yaml:"authRequired"`
	UseClientOAuth bool             `yaml:"useClientOAuth"`
	Parameters     tools.Parameters `yaml:"parameters"`

	Client        *bigqueryapi.Client
	RestService   *bigqueryrestapi.Service
	ClientCreator bigqueryds.BigqueryClientCreator
	manifest      tools.Manifest
	mcpManifest   tools.McpManifest
}

// Invoke runs the contribution analysis.
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
	paramsMap := params.AsMap()
	inputData, ok := paramsMap["input_data"].(string)
	if !ok {
		return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
	}

	modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))

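	// Assemble the OPTIONS clause for CREATE MODEL from the tool parameters.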
	var options []string
	options = append(options, "MODEL_TYPE = 'CONTRIBUTION_ANALYSIS'")
	options = append(options, fmt.Sprintf("CONTRIBUTION_METRIC = '%s'", paramsMap["contribution_metric"]))
	options = append(options, fmt.Sprintf("IS_TEST_COL = '%s'", paramsMap["is_test_col"]))

	if val, ok := paramsMap["dimension_id_cols"]; ok {
		if cols, ok := val.([]any); ok {
			var strCols []string
			for _, c := range cols {
				strCols = append(strCols, fmt.Sprintf("'%s'", c))
			}
			options = append(options, fmt.Sprintf("DIMENSION_ID_COLS = [%s]", strings.Join(strCols, ", ")))
		} else {
			return nil, fmt.Errorf("unable to cast dimension_id_cols parameter %s", paramsMap["dimension_id_cols"])
		}
	}
	if val, ok := paramsMap["top_k_insights_by_apriori_support"]; ok {
		options = append(options, fmt.Sprintf("TOP_K_INSIGHTS_BY_APRIORI_SUPPORT = %v", val))
	}
	if val, ok := paramsMap["pruning_method"].(string); ok {
		upperVal := strings.ToUpper(val)
		if upperVal != "NO_PRUNING" && upperVal != "PRUNE_REDUNDANT_INSIGHTS" {
			return nil, fmt.Errorf("invalid pruning_method: %s", val)
		}
		options = append(options, fmt.Sprintf("PRUNING_METHOD = '%s'", upperVal))
	}

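	// Treat input_data as a query if it looks like one; otherwise assume it is a table ID.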
	var inputDataSource string
	trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData))
	if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
		inputDataSource = fmt.Sprintf("(%s)", inputData)
	} else {
		inputDataSource = fmt.Sprintf("SELECT * FROM `%s`", inputData)
	}

	// Use temp model to skip the clean up at the end. To use TEMP MODEL, queries have to be
	// in the same BigQuery session.
	createModelSQL := fmt.Sprintf("CREATE TEMP MODEL %s OPTIONS(%s) AS %s",
		modelID,
		strings.Join(options, ", "),
		inputDataSource,
	)

	bqClient := t.Client
	var err error

	// Initialize new client if using user OAuth token
	if t.UseClientOAuth {
		tokenStr, err := accessToken.ParseBearerToken()
		if err != nil {
			return nil, fmt.Errorf("error parsing access token: %w", err)
		}
		bqClient, _, err = t.ClientCreator(tokenStr, false)
		if err != nil {
			return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
		}
	}

	createModelQuery := bqClient.Query(createModelSQL)
	createModelQuery.CreateSession = true
	createModelJob, err := createModelQuery.Run(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to start create model job: %w", err)
	}

	status, err := createModelJob.Wait(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to wait for create model job: %w", err)
	}
	if err := status.Err(); err != nil {
		return nil, fmt.Errorf("create model job failed: %w", err)
	}

	if status.Statistics == nil || status.Statistics.SessionInfo == nil || status.Statistics.SessionInfo.SessionID == "" {
		return nil, fmt.Errorf("failed to create a BigQuery session")
	}
	sessionID := status.Statistics.SessionInfo.SessionID
	getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)

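	// Run ML.GET_INSIGHTS in the session that created the temp model, since
	// the model is only visible within that session.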
	getInsightsQuery := bqClient.Query(getInsightsSQL)
	getInsightsQuery.QueryConfig.ConnectionProperties = []*bigqueryapi.ConnectionProperty{
		{Key: "session_id", Value: sessionID},
	}

	job, err := getInsightsQuery.Run(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to execute get insights query: %w", err)
	}
	it, err := job.Read(ctx)
	if err != nil {
		return nil, fmt.Errorf("unable to read query results: %w", err)
	}

	var out []any
	for {
		var row map[string]bigqueryapi.Value
		err := it.Next(&row)
		if err == iterator.Done {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("failed to iterate through query results: %w", err)
		}
		vMap := make(map[string]any)
		for key, value := range row {
			vMap[key] = value
		}
		out = append(out, vMap)
	}

	if len(out) > 0 {
		return out, nil
	}

	// This handles the standard case for a SELECT query that successfully
	// executes but returns zero rows.
	return "The query returned 0 rows.", nil
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
	return tools.ParseParams(t.Parameters, data, claims)
}

func (t Tool) Manifest() tools.Manifest {
	return t.manifest
}

func (t Tool) McpManifest() tools.McpManifest {
	return t.mcpManifest
}

func (t Tool) Authorized(verifiedAuthServices []string) bool {
	return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

func (t Tool) RequiresClientAuthorization() bool {
	return t.UseClientOAuth
}
@@ -0,0 +1,72 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bigqueryanalyzecontribution_test

import (
	"testing"

	yaml "github.com/goccy/go-yaml"
	"github.com/google/go-cmp/cmp"
	"github.com/googleapis/genai-toolbox/internal/server"
	"github.com/googleapis/genai-toolbox/internal/testutils"
	"github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryanalyzecontribution"
)

func TestParseFromYamlBigQueryAnalyzeContribution(t *testing.T) {
	ctx, err := testutils.ContextWithNewLogger()
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	tcs := []struct {
		desc string
		in   string
		want server.ToolConfigs
	}{
		{
			desc: "basic example",
			in: `
			tools:
				example_tool:
					kind: bigquery-analyze-contribution
					source: my-instance
					description: some description
			`,
			want: server.ToolConfigs{
				"example_tool": bigqueryanalyzecontribution.Config{
					Name:         "example_tool",
					Kind:         "bigquery-analyze-contribution",
					Source:       "my-instance",
					Description:  "some description",
					AuthRequired: []string{},
				},
			},
		},
	}
	for _, tc := range tcs {
		t.Run(tc.desc, func(t *testing.T) {
			got := struct {
				Tools server.ToolConfigs `yaml:"tools"`
			}{}
			// Parse contents
			err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
			if err != nil {
				t.Fatalf("unable to unmarshal: %s", err)
			}
			if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
				t.Fatalf("incorrect parse: diff %v", diff)
			}
		})
	}
}
@@ -113,6 +113,12 @@ func TestBigQueryToolEndpoints(t *testing.T) {
 		strings.ReplaceAll(uuid.New().String(), "-", ""),
 	)

+	tableNameAnalyzeContribution := fmt.Sprintf("`%s.%s.analyze_contribution_table_%s`",
+		BigqueryProject,
+		datasetName,
+		strings.ReplaceAll(uuid.New().String(), "-", ""),
+	)
+
 	// set up data for param tool
 	createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
 	teardownTable1 := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
@@ -133,6 +139,11 @@ func TestBigQueryToolEndpoints(t *testing.T) {
 	teardownTable4 := setupBigQueryTable(t, ctx, client, createForecastTableStmt, insertForecastTableStmt, datasetName, tableNameForecast, forecastTestParams)
 	defer teardownTable4(t)

+	// set up data for analyze contribution tool
+	createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, analyzeContributionTestParams := getBigQueryAnalyzeContributionToolInfo(tableNameAnalyzeContribution)
+	teardownTable5 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, datasetName, tableNameAnalyzeContribution, analyzeContributionTestParams)
+	defer teardownTable5(t)
+
 	// Write config into a file and pass it to command
 	toolsFile := tests.GetToolsConfig(sourceConfig, BigqueryToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
 	toolsFile = addClientAuthSourceConfig(t, toolsFile)
@@ -182,6 +193,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
 	runBigQueryExecuteSqlToolInvokeTest(t, select1Want, invokeParamWant, tableNameParam, ddlWant)
 	runBigQueryExecuteSqlToolInvokeDryRunTest(t, datasetName)
 	runBigQueryForecastToolInvokeTest(t, tableNameForecast)
+	runBigQueryAnalyzeContributionToolInvokeTest(t, tableNameAnalyzeContribution)
 	runBigQueryDataTypeTests(t)
 	runBigQueryListDatasetToolInvokeTest(t, datasetName)
 	runBigQueryGetDatasetInfoToolInvokeTest(t, datasetName, datasetInfoWant)
@@ -341,7 +353,7 @@ func getBigQueryForecastToolInfo(tableName string) (string, string, []bigqueryap
 	createStatement := fmt.Sprintf(`
 	CREATE TABLE IF NOT EXISTS %s (ts TIMESTAMP, data FLOAT64, id STRING);`, tableName)
 	insertStatement := fmt.Sprintf(`
-	INSERT INTO %s (ts, data, id) VALUES
+	INSERT INTO %s (ts, data, id) VALUES
 	(?, ?, ?), (?, ?, ?), (?, ?, ?),
 	(?, ?, ?), (?, ?, ?), (?, ?, ?);`, tableName)
 	params := []bigqueryapi.QueryParameter{
@@ -355,6 +367,26 @@ func getBigQueryForecastToolInfo(tableName string) (string, string, []bigqueryap
 	return createStatement, insertStatement, params
 }

+// getBigQueryAnalyzeContributionToolInfo returns statements and params for the analyze-contribution tool.
+func getBigQueryAnalyzeContributionToolInfo(tableName string) (string, string, []bigqueryapi.QueryParameter) {
+	createStatement := fmt.Sprintf(`
+	CREATE TABLE IF NOT EXISTS %s (dim1 STRING, dim2 STRING, is_test BOOL, metric FLOAT64);`, tableName)
+	insertStatement := fmt.Sprintf(`
+	INSERT INTO %s (dim1, dim2, is_test, metric) VALUES
+	(?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?),
+	(?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?);`, tableName)
+	params := []bigqueryapi.QueryParameter{
+		{Value: "a"}, {Value: "x"}, {Value: true}, {Value: 100.0},
+		{Value: "a"}, {Value: "x"}, {Value: false}, {Value: 110.0},
+		{Value: "a"}, {Value: "y"}, {Value: true}, {Value: 120.0},
+		{Value: "a"}, {Value: "y"}, {Value: false}, {Value: 100.0},
+		{Value: "b"}, {Value: "x"}, {Value: true}, {Value: 40.0},
+		{Value: "b"}, {Value: "x"}, {Value: false}, {Value: 100.0},
+		{Value: "b"}, {Value: "y"}, {Value: true}, {Value: 60.0},
+		{Value: "b"}, {Value: "y"}, {Value: false}, {Value: 60.0},
+	}
+	return createStatement, insertStatement, params
+}
+
 // getBigQueryTmplToolStatement returns statements for template parameter test cases for bigquery kind
 func getBigQueryTmplToolStatement() (string, string) {
 	tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = ? ORDER BY id"
@@ -482,6 +514,24 @@ func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[str
 		"source":      "my-client-auth-source",
 		"description": "Tool to forecast time series data with auth.",
 	}
+	tools["my-analyze-contribution-tool"] = map[string]any{
+		"kind":        "bigquery-analyze-contribution",
+		"source":      "my-instance",
+		"description": "Tool to analyze contribution.",
+	}
+	tools["my-auth-analyze-contribution-tool"] = map[string]any{
+		"kind":        "bigquery-analyze-contribution",
+		"source":      "my-instance",
+		"description": "Tool to analyze contribution with auth.",
+		"authRequired": []string{
+			"my-google-auth",
+		},
+	}
+	tools["my-client-auth-analyze-contribution-tool"] = map[string]any{
+		"kind":        "bigquery-analyze-contribution",
+		"source":      "my-client-auth-source",
+		"description": "Tool to analyze contribution with auth.",
+	}
 	tools["my-list-dataset-ids-tool"] = map[string]any{
 		"kind":   "bigquery-list-dataset-ids",
 		"source": "my-instance",
@@ -1051,6 +1101,127 @@ func runBigQueryForecastToolInvokeTest(t *testing.T, tableName string) {
 	}
 }

+func runBigQueryAnalyzeContributionToolInvokeTest(t *testing.T, tableName string) {
+	idToken, err := tests.GetGoogleIdToken(tests.ClientId)
+	if err != nil {
+		t.Fatalf("error getting Google ID token: %s", err)
+	}
+
+	// Get access token
+	accessToken, err := sources.GetIAMAccessToken(t.Context())
+	if err != nil {
+		t.Fatalf("error getting access token from ADC: %s", err)
+	}
+	accessToken = "Bearer " + accessToken
+
+	dataTable := strings.ReplaceAll(tableName, "`", "")
+
+	invokeTcs := []struct {
+		name          string
+		api           string
+		requestHeader map[string]string
+		requestBody   io.Reader
+		want          string
+		isErr         bool
+	}{
+		{
+			name:          "invoke my-analyze-contribution-tool without required params",
+			api:           "http://127.0.0.1:5000/api/tool/my-analyze-contribution-tool/invoke",
+			requestHeader: map[string]string{},
+			requestBody:   bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s"}`, dataTable))),
+			isErr:         true,
+		},
+		{
+			name:          "invoke my-analyze-contribution-tool with table",
+			api:           "http://127.0.0.1:5000/api/tool/my-analyze-contribution-tool/invoke",
+			requestHeader: map[string]string{},
+			requestBody:   bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))),
+			want:          `"relative_difference"`,
+			isErr:         false,
+		},
+		{
+			name:          "invoke my-auth-analyze-contribution-tool with auth token",
+			api:           "http://127.0.0.1:5000/api/tool/my-auth-analyze-contribution-tool/invoke",
+			requestHeader: map[string]string{"my-google-auth_token": idToken},
+			requestBody:   bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))),
+			want:          `"relative_difference"`,
+			isErr:         false,
+		},
+		{
+			name:          "invoke my-auth-analyze-contribution-tool with invalid auth token",
+			api:           "http://127.0.0.1:5000/api/tool/my-auth-analyze-contribution-tool/invoke",
+			requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
+			requestBody:   bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))),
+			isErr:         true,
+		},
+		{
+			name:          "Invoke my-client-auth-analyze-contribution-tool with auth token",
+			api:           "http://127.0.0.1:5000/api/tool/my-client-auth-analyze-contribution-tool/invoke",
+			requestHeader: map[string]string{"Authorization": accessToken},
+			requestBody:   bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))),
+			want:          `"relative_difference"`,
+			isErr:         false,
+		},
+		{
+			name:          "Invoke my-client-auth-analyze-contribution-tool without auth token",
+			api:           "http://127.0.0.1:5000/api/tool/my-client-auth-analyze-contribution-tool/invoke",
+			requestHeader: map[string]string{},
+			requestBody:   bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))),
+			isErr:         true,
+		},
+		{
+			name:          "Invoke my-client-auth-analyze-contribution-tool with invalid auth token",
+			api:           "http://127.0.0.1:5000/api/tool/my-client-auth-analyze-contribution-tool/invoke",
+			requestHeader: map[string]string{"Authorization": "Bearer invalid-token"},
+			requestBody:   bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))),
+			isErr:         true,
+		},
+	}
+	for _, tc := range invokeTcs {
+		t.Run(tc.name, func(t *testing.T) {
+			// Send Tool invocation request
+			req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
+			if err != nil {
+				t.Fatalf("unable to create request: %s", err)
+			}
+			req.Header.Add("Content-type", "application/json")
+			for k, v := range tc.requestHeader {
+				req.Header.Add(k, v)
+			}
+			resp, err := http.DefaultClient.Do(req)
+			if err != nil {
+				t.Fatalf("unable to send request: %s", err)
+			}
+			defer resp.Body.Close()
+
+			if resp.StatusCode != http.StatusOK {
+				if tc.isErr {
+					return
+				}
+				bodyBytes, _ := io.ReadAll(resp.Body)
+				t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
+			}
+
+			// Check response body
+			var body map[string]interface{}
+			err = json.NewDecoder(resp.Body).Decode(&body)
+			if err != nil {
+				t.Fatalf("error parsing response body")
+			}
+
+			got, ok := body["result"].(string)
+			if !ok {
+				t.Fatalf("unable to find result in response body")
+			}
+
+			if !strings.Contains(got, tc.want) {
+				t.Fatalf("expected %q to contain %q, but it did not", got, tc.want)
+			}
+		})
+	}
+}
+
 func runBigQueryDataTypeTests(t *testing.T) {
 	// Test tool invoke endpoint
 	invokeTcs := []struct {