mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 16:08:16 -05:00
feat(tools/bigquery-forecast): Add bigqueryforecast tool (#1148)
This tool wraps the BigQuery's AI.FORECAST function to do the time series forecasting. Co-authored-by: Averi Kitsch <akitsch@google.com>
This commit is contained in:
@@ -44,6 +44,7 @@ import (
|
||||
// Import tool packages for side effect of registration
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydbainl"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryforecast"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygetdatasetinfo"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygettableinfo"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylistdatasetids"
|
||||
|
||||
@@ -1210,7 +1210,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"bigquery-database-tools": tools.ToolsetConfig{
|
||||
Name: "bigquery-database-tools",
|
||||
ToolNames: []string{"execute_sql", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids"},
|
||||
ToolNames: []string{"execute_sql", "forecast", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
49
docs/en/resources/tools/bigquery/bigquery-forecast.md
Normal file
49
docs/en/resources/tools/bigquery/bigquery-forecast.md
Normal file
@@ -0,0 +1,49 @@
|
||||
---
|
||||
title: "bigquery-forecast"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "bigquery-forecast" tool forecasts time series data in BigQuery.
|
||||
aliases:
|
||||
- /resources/tools/bigquery-forecast
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `bigquery-forecast` tool forecasts time series data in BigQuery.
|
||||
It's compatible with the following sources:
|
||||
|
||||
- [bigquery](../../sources/bigquery.md)
|
||||
|
||||
`bigquery-forecast` constructs and executes a `SELECT * FROM AI.FORECAST(...)` query based on the provided parameters:
|
||||
|
||||
- **history_data** (string, required): This specifies the source of the historical time series data. It can be either a fully qualified BigQuery table ID (e.g., my-project.my_dataset.my_table) or a SQL query that returns the data.
|
||||
- **timestamp_col** (string, required): The name of the column in your history_data that contains the timestamps.
|
||||
- **data_col** (string, required): The name of the column in your history_data that contains the numeric values to be forecasted.
|
||||
- **id_cols** (array of strings, optional): If you are forecasting multiple time series at once (e.g., sales for different products), this parameter takes an array of column names that uniquely identify each series. It defaults to an empty array if not provided.
|
||||
- **horizon** (integer, optional): The number of future time steps you want to predict. It defaults to 10 if not specified.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
forecast_tool:
|
||||
kind: bigquery-forecast
|
||||
source: my-bigquery-source
|
||||
description: Use this tool to forecast time series data in BigQuery.
|
||||
```
|
||||
|
||||
## Sample Prompt
|
||||
You can use the following sample prompts to call this tool:
|
||||
|
||||
- Can you forecast the history time series data in bigquery table `bqml_tutorial.google_analytic`? Use project_id `myproject`.
|
||||
- What are the future `total_visits` in bigquery table `bqml_tutorial.google_analytic`?
|
||||
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "bigquery-forecast". |
|
||||
| source | string | true | Name of the source the forecast tool should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
@@ -9,6 +9,11 @@ tools:
|
||||
source: bigquery-source
|
||||
description: Use this tool to execute sql statement.
|
||||
|
||||
forecast:
|
||||
kind: bigquery-forecast
|
||||
source: bigquery-source
|
||||
description: Use this tool to forecast time series data.
|
||||
|
||||
get_dataset_info:
|
||||
kind: bigquery-get-dataset-info
|
||||
source: bigquery-source
|
||||
@@ -32,6 +37,7 @@ tools:
|
||||
toolsets:
|
||||
bigquery-database-tools:
|
||||
- execute_sql
|
||||
- forecast
|
||||
- get_dataset_info
|
||||
- get_table_info
|
||||
- list_dataset_ids
|
||||
|
||||
247
internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go
Normal file
247
internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go
Normal file
@@ -0,0 +1,247 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package bigqueryforecast
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
const kind string = "bigquery-forecast"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
BigQueryClient() *bigqueryapi.Client
|
||||
BigQueryRestService() *bigqueryrestapi.Service
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
historyDataParameter := tools.NewStringParameter("history_data",
|
||||
"The table id or the query of the history time series data.")
|
||||
timestampColumnNameParameter := tools.NewStringParameter("timestamp_col",
|
||||
"The name of the time series timestamp column.")
|
||||
dataColumnNameParameter := tools.NewStringParameter("data_col",
|
||||
"The name of the time series data column.")
|
||||
idColumnNameParameter := tools.NewArrayParameterWithDefault("id_cols", []any{},
|
||||
"An array of the time series id column names.",
|
||||
tools.NewStringParameter("id_col", "The name of time series id column."))
|
||||
horizonParameter := tools.NewIntParameterWithDefault("horizon", 10, "The number of forecasting steps.")
|
||||
parameters := tools.Parameters{historyDataParameter,
|
||||
timestampColumnNameParameter, dataColumnNameParameter, idColumnNameParameter, horizonParameter}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
historyData, ok := paramsMap["history_data"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast history_data parameter %v", paramsMap["history_data"])
|
||||
}
|
||||
timestampCol, ok := paramsMap["timestamp_col"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast timestamp_col parameter %v", paramsMap["timestamp_col"])
|
||||
}
|
||||
dataCol, ok := paramsMap["data_col"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast data_col parameter %v", paramsMap["data_col"])
|
||||
}
|
||||
idColsRaw, ok := paramsMap["id_cols"].([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast id_cols parameter %v", paramsMap["id_cols"])
|
||||
}
|
||||
var idCols []string
|
||||
for _, v := range idColsRaw {
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("id_cols contains non-string value: %v", v)
|
||||
}
|
||||
idCols = append(idCols, s)
|
||||
}
|
||||
horizon, ok := paramsMap["horizon"].(int)
|
||||
if !ok {
|
||||
if h, ok := paramsMap["horizon"].(float64); ok {
|
||||
horizon = int(h)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unable to cast horizon parameter %v", paramsMap["horizon"])
|
||||
}
|
||||
}
|
||||
|
||||
var historyDataSource string
|
||||
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
|
||||
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
|
||||
historyDataSource = fmt.Sprintf("(%s)", historyData)
|
||||
} else {
|
||||
historyDataSource = fmt.Sprintf("TABLE `%s`", historyData)
|
||||
}
|
||||
|
||||
idColsArg := ""
|
||||
if len(idCols) > 0 {
|
||||
idColsFormatted := fmt.Sprintf("['%s']", strings.Join(idCols, "', '"))
|
||||
idColsArg = fmt.Sprintf(", id_cols => %s", idColsFormatted)
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf(`SELECT *
|
||||
FROM AI.FORECAST(
|
||||
%s,
|
||||
data_col => '%s',
|
||||
timestamp_col => '%s',
|
||||
horizon => %d%s)`,
|
||||
historyDataSource, dataCol, timestampCol, horizon, idColsArg)
|
||||
|
||||
// JobStatistics.QueryStatistics.StatementType
|
||||
query := t.Client.Query(sql)
|
||||
query.Location = t.Client.Location
|
||||
|
||||
// Log the query executed for debugging.
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
||||
}
|
||||
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sql)
|
||||
|
||||
// This block handles SELECT statements, which return a row set.
|
||||
// We iterate through the results, convert each row into a map of
|
||||
// column names to values, and return the collection of rows.
|
||||
var out []any
|
||||
it, err := query.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
for {
|
||||
var row map[string]bigqueryapi.Value
|
||||
err = it.Next(&row)
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for key, value := range row {
|
||||
vMap[key] = value
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
// If the query returned any rows, return them directly.
|
||||
if len(out) > 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// This handles the standard case for a SELECT query that successfully
|
||||
return "The query returned 0 rows.", nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package bigqueryforecast_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryforecast"
|
||||
)
|
||||
|
||||
func TestParseFromYamlBigQueryForecast(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: bigquery-forecast
|
||||
source: my-instance
|
||||
description: some description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": bigqueryforecast.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "bigquery-forecast",
|
||||
Source: "my-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -105,6 +105,11 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
datasetName,
|
||||
strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
)
|
||||
tableNameForecast := fmt.Sprintf("`%s.%s.forecast_table_%s`",
|
||||
BigqueryProject,
|
||||
datasetName,
|
||||
strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
)
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
|
||||
@@ -121,6 +126,11 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
teardownTable3 := setupBigQueryTable(t, ctx, client, createDataTypeTableStmt, insertDataTypeTableStmt, datasetName, tableNameDataType, dataTypeTestParams)
|
||||
defer teardownTable3(t)
|
||||
|
||||
// set up data for forecast tool
|
||||
createForecastTableStmt, insertForecastTableStmt, forecastTestParams := getBigQueryForecastToolInfo(tableNameForecast)
|
||||
teardownTable4 := setupBigQueryTable(t, ctx, client, createForecastTableStmt, insertForecastTableStmt, datasetName, tableNameForecast, forecastTestParams)
|
||||
defer teardownTable4(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BigqueryToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = addBigQuerySqlToolConfig(t, toolsFile, dataTypeToolStmt, arrayDataTypeToolStmt)
|
||||
@@ -163,6 +173,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
|
||||
runBigQueryExecuteSqlToolInvokeTest(t, select1Want, invokeParamWant, tableNameParam, ddlWant)
|
||||
runBigQueryExecuteSqlToolInvokeDryRunTest(t, datasetName)
|
||||
runBigQueryForecastToolInvokeTest(t, tableNameForecast)
|
||||
runBigQueryDataTypeTests(t)
|
||||
runBigQueryListDatasetToolInvokeTest(t, datasetName)
|
||||
runBigQueryGetDatasetInfoToolInvokeTest(t, datasetName, datasetInfoWant)
|
||||
@@ -220,6 +231,25 @@ func getBigQueryDataTypeTestInfo(tableName string) (string, string, string, stri
|
||||
return createStatement, insertStatement, toolStatement, arrayToolStatement, params
|
||||
}
|
||||
|
||||
// getBigQueryForecastToolInfo returns statements and params for the forecast tool.
|
||||
func getBigQueryForecastToolInfo(tableName string) (string, string, []bigqueryapi.QueryParameter) {
|
||||
createStatement := fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (ts TIMESTAMP, data FLOAT64, id STRING);`, tableName)
|
||||
insertStatement := fmt.Sprintf(`
|
||||
INSERT INTO %s (ts, data, id) VALUES
|
||||
(?, ?, ?), (?, ?, ?), (?, ?, ?),
|
||||
(?, ?, ?), (?, ?, ?), (?, ?, ?);`, tableName)
|
||||
params := []bigqueryapi.QueryParameter{
|
||||
{Value: "2025-01-01T00:00:00Z"}, {Value: 10.0}, {Value: "a"},
|
||||
{Value: "2025-01-01T01:00:00Z"}, {Value: 11.0}, {Value: "a"},
|
||||
{Value: "2025-01-01T02:00:00Z"}, {Value: 12.0}, {Value: "a"},
|
||||
{Value: "2025-01-01T00:00:00Z"}, {Value: 20.0}, {Value: "b"},
|
||||
{Value: "2025-01-01T01:00:00Z"}, {Value: 21.0}, {Value: "b"},
|
||||
{Value: "2025-01-01T02:00:00Z"}, {Value: 22.0}, {Value: "b"},
|
||||
}
|
||||
return createStatement, insertStatement, params
|
||||
}
|
||||
|
||||
// getBigQueryTmplToolStatement returns statements for template parameter test cases for bigquery kind
|
||||
func getBigQueryTmplToolStatement() (string, string) {
|
||||
tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = ? ORDER BY id"
|
||||
@@ -322,6 +352,19 @@ func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[str
|
||||
"my-google-auth",
|
||||
},
|
||||
}
|
||||
tools["my-forecast-tool"] = map[string]any{
|
||||
"kind": "bigquery-forecast",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to forecast time series data.",
|
||||
}
|
||||
tools["my-auth-forecast-tool"] = map[string]any{
|
||||
"kind": "bigquery-forecast",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to forecast time series data with auth.",
|
||||
"authRequired": []string{
|
||||
"my-google-auth",
|
||||
},
|
||||
}
|
||||
tools["my-list-dataset-ids-tool"] = map[string]any{
|
||||
"kind": "bigquery-list-dataset-ids",
|
||||
"source": "my-instance",
|
||||
@@ -666,6 +709,114 @@ func runBigQueryExecuteSqlToolInvokeDryRunTest(t *testing.T, datasetName string)
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQueryForecastToolInvokeTest(t *testing.T, tableName string) {
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
if err != nil {
|
||||
t.Fatalf("error getting Google ID token: %s", err)
|
||||
}
|
||||
|
||||
historyDataTable := strings.ReplaceAll(tableName, "`", "")
|
||||
historyDataQuery := fmt.Sprintf("SELECT ts, data, id FROM %s", tableName)
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke my-forecast-tool without required params",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-forecast-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s"}`, historyDataTable))),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "invoke my-forecast-tool with table",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-forecast-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data"}`, historyDataTable))),
|
||||
want: `"forecast_timestamp"`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-forecast-tool with query and horizon",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-forecast-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data", "horizon": 5}`, historyDataQuery))),
|
||||
want: `"forecast_timestamp"`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-forecast-tool with id_cols",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-forecast-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data", "id_cols": ["id"]}`, historyDataTable))),
|
||||
want: `"id"`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-auth-forecast-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-forecast-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data"}`, historyDataTable))),
|
||||
want: `"forecast_timestamp"`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-auth-forecast-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-forecast-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data"}`, historyDataTable))),
|
||||
isErr: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if !strings.Contains(got, tc.want) {
|
||||
t.Fatalf("expected %q to contain %q, but it did not", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQueryDataTypeTests(t *testing.T) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
|
||||
Reference in New Issue
Block a user