feat(tools/bigquery-forecast): Add bigqueryforecast tool (#1148)

This tool wraps the BigQuery's AI.FORECAST function to do the time
series forecasting.

Co-authored-by: Averi Kitsch <akitsch@google.com>
This commit is contained in:
Haoming Chen
2025-08-18 13:45:54 -07:00
committed by GitHub
parent 3abc9e00b5
commit 2ad0ccf83d
7 changed files with 528 additions and 1 deletions

View File

@@ -44,6 +44,7 @@ import (
// Import tool packages for side effect of registration
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydbainl"
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryexecutesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryforecast"
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygetdatasetinfo"
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygettableinfo"
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylistdatasetids"

View File

@@ -1210,7 +1210,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"bigquery-database-tools": tools.ToolsetConfig{
Name: "bigquery-database-tools",
ToolNames: []string{"execute_sql", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids"},
ToolNames: []string{"execute_sql", "forecast", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids"},
},
},
},

View File

@@ -0,0 +1,49 @@
---
title: "bigquery-forecast"
type: docs
weight: 1
description: >
A "bigquery-forecast" tool forecasts time series data in BigQuery.
aliases:
- /resources/tools/bigquery-forecast
---
## About
A `bigquery-forecast` tool forecasts time series data in BigQuery.
It's compatible with the following sources:
- [bigquery](../../sources/bigquery.md)
`bigquery-forecast` constructs and executes a `SELECT * FROM AI.FORECAST(...)` query based on the provided parameters:
- **history_data** (string, required): This specifies the source of the historical time series data. It can be either a fully qualified BigQuery table ID (e.g., my-project.my_dataset.my_table) or a SQL query that returns the data.
- **timestamp_col** (string, required): The name of the column in your history_data that contains the timestamps.
- **data_col** (string, required): The name of the column in your history_data that contains the numeric values to be forecasted.
- **id_cols** (array of strings, optional): If you are forecasting multiple time series at once (e.g., sales for different products), this parameter takes an array of column names that uniquely identify each series. It defaults to an empty array if not provided.
- **horizon** (integer, optional): The number of future time steps you want to predict. It defaults to 10 if not specified.
## Example
```yaml
tools:
forecast_tool:
kind: bigquery-forecast
source: my-bigquery-source
description: Use this tool to forecast time series data in BigQuery.
```
## Sample Prompt
You can use the following sample prompts to call this tool:
- Can you forecast the history time series data in bigquery table `bqml_tutorial.google_analytic`? Use project_id `myproject`.
- What are the future `total_visits` in bigquery table `bqml_tutorial.google_analytic`?
## Reference
| **field** | **type** | **required** | **description** |
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "bigquery-forecast". |
| source | string | true | Name of the source the forecast tool should execute on. |
| description | string | true | Description of the tool that is passed to the LLM. |

View File

@@ -9,6 +9,11 @@ tools:
source: bigquery-source
description: Use this tool to execute sql statement.
forecast:
kind: bigquery-forecast
source: bigquery-source
description: Use this tool to forecast time series data.
get_dataset_info:
kind: bigquery-get-dataset-info
source: bigquery-source
@@ -32,6 +37,7 @@ tools:
toolsets:
bigquery-database-tools:
- execute_sql
- forecast
- get_dataset_info
- get_table_info
- list_dataset_ids

View File

@@ -0,0 +1,247 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bigqueryforecast
import (
"context"
"fmt"
"strings"
bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
const kind string = "bigquery-forecast"
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryRestService() *bigqueryrestapi.Service
}
// validate compatible sources are still compatible
var _ compatibleSource = &bigqueryds.Source{}
var compatibleSources = [...]string{bigqueryds.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
}
// validate interface
var _ tools.ToolConfig = Config{}
func (cfg Config) ToolConfigKind() string {
return kind
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
historyDataParameter := tools.NewStringParameter("history_data",
"The table id or the query of the history time series data.")
timestampColumnNameParameter := tools.NewStringParameter("timestamp_col",
"The name of the time series timestamp column.")
dataColumnNameParameter := tools.NewStringParameter("data_col",
"The name of the time series data column.")
idColumnNameParameter := tools.NewArrayParameterWithDefault("id_cols", []any{},
"An array of the time series id column names.",
tools.NewStringParameter("id_col", "The name of time series id column."))
horizonParameter := tools.NewIntParameterWithDefault("horizon", 10, "The number of forecasting steps.")
parameters := tools.Parameters{historyDataParameter,
timestampColumnNameParameter, dataColumnNameParameter, idColumnNameParameter, horizonParameter}
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
InputSchema: parameters.McpManifest(),
}
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
// validate interface
var _ tools.Tool = Tool{}
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
paramsMap := params.AsMap()
historyData, ok := paramsMap["history_data"].(string)
if !ok {
return nil, fmt.Errorf("unable to cast history_data parameter %v", paramsMap["history_data"])
}
timestampCol, ok := paramsMap["timestamp_col"].(string)
if !ok {
return nil, fmt.Errorf("unable to cast timestamp_col parameter %v", paramsMap["timestamp_col"])
}
dataCol, ok := paramsMap["data_col"].(string)
if !ok {
return nil, fmt.Errorf("unable to cast data_col parameter %v", paramsMap["data_col"])
}
idColsRaw, ok := paramsMap["id_cols"].([]any)
if !ok {
return nil, fmt.Errorf("unable to cast id_cols parameter %v", paramsMap["id_cols"])
}
var idCols []string
for _, v := range idColsRaw {
s, ok := v.(string)
if !ok {
return nil, fmt.Errorf("id_cols contains non-string value: %v", v)
}
idCols = append(idCols, s)
}
horizon, ok := paramsMap["horizon"].(int)
if !ok {
if h, ok := paramsMap["horizon"].(float64); ok {
horizon = int(h)
} else {
return nil, fmt.Errorf("unable to cast horizon parameter %v", paramsMap["horizon"])
}
}
var historyDataSource string
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
historyDataSource = fmt.Sprintf("(%s)", historyData)
} else {
historyDataSource = fmt.Sprintf("TABLE `%s`", historyData)
}
idColsArg := ""
if len(idCols) > 0 {
idColsFormatted := fmt.Sprintf("['%s']", strings.Join(idCols, "', '"))
idColsArg = fmt.Sprintf(", id_cols => %s", idColsFormatted)
}
sql := fmt.Sprintf(`SELECT *
FROM AI.FORECAST(
%s,
data_col => '%s',
timestamp_col => '%s',
horizon => %d%s)`,
historyDataSource, dataCol, timestampCol, horizon, idColsArg)
// JobStatistics.QueryStatistics.StatementType
query := t.Client.Query(sql)
query.Location = t.Client.Location
// Log the query executed for debugging.
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("error getting logger: %s", err)
}
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sql)
// This block handles SELECT statements, which return a row set.
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
var out []any
it, err := query.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
for {
var row map[string]bigqueryapi.Value
err = it.Next(&row)
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
}
vMap := make(map[string]any)
for key, value := range row {
vMap[key] = value
}
out = append(out, vMap)
}
// If the query returned any rows, return them directly.
if len(out) > 0 {
return out, nil
}
// This handles the standard case for a SELECT query that successfully
return "The query returned 0 rows.", nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.Parameters, data, claims)
}
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -0,0 +1,73 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bigqueryforecast_test
import (
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryforecast"
)
func TestParseFromYamlBigQueryForecast(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
example_tool:
kind: bigquery-forecast
source: my-instance
description: some description
`,
want: server.ToolConfigs{
"example_tool": bigqueryforecast.Config{
Name: "example_tool",
Kind: "bigquery-forecast",
Source: "my-instance",
Description: "some description",
AuthRequired: []string{},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

View File

@@ -105,6 +105,11 @@ func TestBigQueryToolEndpoints(t *testing.T) {
datasetName,
strings.ReplaceAll(uuid.New().String(), "-", ""),
)
tableNameForecast := fmt.Sprintf("`%s.%s.forecast_table_%s`",
BigqueryProject,
datasetName,
strings.ReplaceAll(uuid.New().String(), "-", ""),
)
// set up data for param tool
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
@@ -121,6 +126,11 @@ func TestBigQueryToolEndpoints(t *testing.T) {
teardownTable3 := setupBigQueryTable(t, ctx, client, createDataTypeTableStmt, insertDataTypeTableStmt, datasetName, tableNameDataType, dataTypeTestParams)
defer teardownTable3(t)
// set up data for forecast tool
createForecastTableStmt, insertForecastTableStmt, forecastTestParams := getBigQueryForecastToolInfo(tableNameForecast)
teardownTable4 := setupBigQueryTable(t, ctx, client, createForecastTableStmt, insertForecastTableStmt, datasetName, tableNameForecast, forecastTestParams)
defer teardownTable4(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, BigqueryToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = addBigQuerySqlToolConfig(t, toolsFile, dataTypeToolStmt, arrayDataTypeToolStmt)
@@ -163,6 +173,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
runBigQueryExecuteSqlToolInvokeTest(t, select1Want, invokeParamWant, tableNameParam, ddlWant)
runBigQueryExecuteSqlToolInvokeDryRunTest(t, datasetName)
runBigQueryForecastToolInvokeTest(t, tableNameForecast)
runBigQueryDataTypeTests(t)
runBigQueryListDatasetToolInvokeTest(t, datasetName)
runBigQueryGetDatasetInfoToolInvokeTest(t, datasetName, datasetInfoWant)
@@ -220,6 +231,25 @@ func getBigQueryDataTypeTestInfo(tableName string) (string, string, string, stri
return createStatement, insertStatement, toolStatement, arrayToolStatement, params
}
// getBigQueryForecastToolInfo returns statements and params for the forecast tool.
func getBigQueryForecastToolInfo(tableName string) (string, string, []bigqueryapi.QueryParameter) {
createStatement := fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (ts TIMESTAMP, data FLOAT64, id STRING);`, tableName)
insertStatement := fmt.Sprintf(`
INSERT INTO %s (ts, data, id) VALUES
(?, ?, ?), (?, ?, ?), (?, ?, ?),
(?, ?, ?), (?, ?, ?), (?, ?, ?);`, tableName)
params := []bigqueryapi.QueryParameter{
{Value: "2025-01-01T00:00:00Z"}, {Value: 10.0}, {Value: "a"},
{Value: "2025-01-01T01:00:00Z"}, {Value: 11.0}, {Value: "a"},
{Value: "2025-01-01T02:00:00Z"}, {Value: 12.0}, {Value: "a"},
{Value: "2025-01-01T00:00:00Z"}, {Value: 20.0}, {Value: "b"},
{Value: "2025-01-01T01:00:00Z"}, {Value: 21.0}, {Value: "b"},
{Value: "2025-01-01T02:00:00Z"}, {Value: 22.0}, {Value: "b"},
}
return createStatement, insertStatement, params
}
// getBigQueryTmplToolStatement returns statements for template parameter test cases for bigquery kind
func getBigQueryTmplToolStatement() (string, string) {
tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = ? ORDER BY id"
@@ -322,6 +352,19 @@ func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[str
"my-google-auth",
},
}
tools["my-forecast-tool"] = map[string]any{
"kind": "bigquery-forecast",
"source": "my-instance",
"description": "Tool to forecast time series data.",
}
tools["my-auth-forecast-tool"] = map[string]any{
"kind": "bigquery-forecast",
"source": "my-instance",
"description": "Tool to forecast time series data with auth.",
"authRequired": []string{
"my-google-auth",
},
}
tools["my-list-dataset-ids-tool"] = map[string]any{
"kind": "bigquery-list-dataset-ids",
"source": "my-instance",
@@ -666,6 +709,114 @@ func runBigQueryExecuteSqlToolInvokeDryRunTest(t *testing.T, datasetName string)
}
}
func runBigQueryForecastToolInvokeTest(t *testing.T, tableName string) {
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
if err != nil {
t.Fatalf("error getting Google ID token: %s", err)
}
historyDataTable := strings.ReplaceAll(tableName, "`", "")
historyDataQuery := fmt.Sprintf("SELECT ts, data, id FROM %s", tableName)
invokeTcs := []struct {
name string
api string
requestHeader map[string]string
requestBody io.Reader
want string
isErr bool
}{
{
name: "invoke my-forecast-tool without required params",
api: "http://127.0.0.1:5000/api/tool/my-forecast-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s"}`, historyDataTable))),
isErr: true,
},
{
name: "invoke my-forecast-tool with table",
api: "http://127.0.0.1:5000/api/tool/my-forecast-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data"}`, historyDataTable))),
want: `"forecast_timestamp"`,
isErr: false,
},
{
name: "invoke my-forecast-tool with query and horizon",
api: "http://127.0.0.1:5000/api/tool/my-forecast-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data", "horizon": 5}`, historyDataQuery))),
want: `"forecast_timestamp"`,
isErr: false,
},
{
name: "invoke my-forecast-tool with id_cols",
api: "http://127.0.0.1:5000/api/tool/my-forecast-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data", "id_cols": ["id"]}`, historyDataTable))),
want: `"id"`,
isErr: false,
},
{
name: "invoke my-auth-forecast-tool with auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-forecast-tool/invoke",
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data"}`, historyDataTable))),
want: `"forecast_timestamp"`,
isErr: false,
},
{
name: "invoke my-auth-forecast-tool with invalid auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-forecast-tool/invoke",
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"history_data": "%s", "timestamp_col": "ts", "data_col": "data"}`, historyDataTable))),
isErr: true,
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
// Send Tool invocation request
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
for k, v := range tc.requestHeader {
req.Header.Add(k, v)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if tc.isErr {
return
}
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
// Check response body
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["result"].(string)
if !ok {
t.Fatalf("unable to find result in response body")
}
if !strings.Contains(got, tc.want) {
t.Fatalf("expected %q to contain %q, but it did not", got, tc.want)
}
})
}
}
func runBigQueryDataTypeTests(t *testing.T) {
// Test tool invoke endpoint
invokeTcs := []struct {