mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-12 00:49:08 -05:00
feat(tools/bigquery-ask-data-insights): add bigquery ask-data-insights tool (#932)
1. Add ask-data-insights tool based on conversational analytic API. 2. Add tokenSource for ask-data-insights tool, it uses access token instead of client or restService. 3. Add a max row count to source, currently fixed to 50 and used only for ask-data-insights tool. Later we may make it available for user to make change and apply to bigquery-execute-sql and bigquery-sql to avoid return too many data by accident. --------- Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com> Co-authored-by: Averi Kitsch <akitsch@google.com>
This commit is contained in:
@@ -43,6 +43,7 @@ import (
|
||||
|
||||
// Import tool packages for side effect of registration
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydbainl"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryconversationalanalytics"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryforecast"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygetdatasetinfo"
|
||||
|
||||
@@ -1279,7 +1279,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"bigquery-database-tools": tools.ToolsetConfig{
|
||||
Name: "bigquery-database-tools",
|
||||
ToolNames: []string{"execute_sql", "forecast", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids"},
|
||||
ToolNames: []string{"ask_data_insights", "execute_sql", "forecast", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -53,8 +53,11 @@ See guides, [Connect from your IDE](../how-to/connect-ide/_index.md), for detail
|
||||
* **BigQuery User** (`roles/bigquery.user`) to execute queries and view metadata.
|
||||
* **BigQuery Metadata Viewer** (`roles/bigquery.metadataViewer`) to view all datasets.
|
||||
* **BigQuery Data Editor** (`roles/bigquery.dataEditor`) to create or modify datasets and tables.
|
||||
* **Gemini for Google Cloud** (`roles/cloudaicompanion.user`) to use the conversational analytics API.
|
||||
* **Tools:**
|
||||
* `ask_data_insights`: Use this tool to perform data analysis, get insights, or answer complex questions about the contents of specific BigQuery tables. For more information on required roles, API setup, and IAM configuration, see the setup and authentication section of the [Conversational Analytics API documentation](https://cloud.google.com/gemini/docs/conversational-analytics-api/overview).
|
||||
* `execute_sql`: Executes a SQL statement.
|
||||
* `forecast`: Use this tool to forecast time series data.
|
||||
* `get_dataset_info`: Gets dataset metadata.
|
||||
* `get_table_info`: Gets table metadata.
|
||||
* `list_dataset_ids`: Lists datasets.
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
---
|
||||
title: "bigquery-conversational-analytics"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "bigquery-conversational-analytics" tool allows conversational interaction with a BigQuery source.
|
||||
aliases:
|
||||
- /resources/tools/bigquery-conversational-analytics
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `bigquery-conversational-analytics` tool allows you to ask questions about your data in natural language.
|
||||
|
||||
This function takes a user's question (which can include conversational history for context)
|
||||
and references to specific BigQuery tables, and sends them to a stateless conversational API.
|
||||
|
||||
The API uses a GenAI agent to understand the question, generate and execute SQL queries
|
||||
and Python code, and formulate an answer. This function returns a detailed, sequential
|
||||
log of this entire process, which includes any generated SQL or Python code, the data
|
||||
retrieved, and the final text answer.
|
||||
|
||||
**Note**: This tool requires additional setup in your project. Please refer to the
|
||||
official [Conversational Analytics API documentation](https://cloud.google.com/gemini/docs/conversational-analytics-api/overview)
|
||||
for instructions.
|
||||
|
||||
It's compatible with the following sources:
|
||||
|
||||
- [bigquery](../sources/bigquery.md)
|
||||
|
||||
The tool takes the following input parameters:
|
||||
|
||||
* `user_query_with_context`: The user's question, potentially including conversation history and system instructions for context.
|
||||
* `table_references`: A JSON string of a list of BigQuery tables to use as context. Each object in the list must contain `projectId`, `datasetId`, and `tableId`. Example: `'[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'`
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
ask_data_insights:
|
||||
kind: bigquery-conversational-analytics
|
||||
source: my-bigquery-source
|
||||
description: |
|
||||
Use this tool to perform data analysis, get insights, or answer complex
|
||||
questions about the contents of specific BigQuery tables.
|
||||
```
|
||||
|
||||
## Reference
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "bigquery-conversational-analytics". |
|
||||
| source | string | true | Name of the source for chat. |
|
||||
| description | string | true | Description of the tool
|
||||
that is passed to the LLM. |
|
||||
@@ -4,6 +4,14 @@ sources:
|
||||
project: ${BIGQUERY_PROJECT}
|
||||
|
||||
tools:
|
||||
ask_data_insights:
|
||||
kind: bigquery-conversational-analytics
|
||||
source: bigquery-source
|
||||
description: |
|
||||
Use this tool to perform data analysis, get insights,
|
||||
or answer complex questions about the contents of specific
|
||||
BigQuery tables.
|
||||
|
||||
execute_sql:
|
||||
kind: bigquery-execute-sql
|
||||
source: bigquery-source
|
||||
@@ -36,6 +44,7 @@ tools:
|
||||
|
||||
toolsets:
|
||||
bigquery-database-tools:
|
||||
- ask_data_insights
|
||||
- execute_sql
|
||||
- forecast
|
||||
- get_dataset_info
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
"google.golang.org/api/option"
|
||||
@@ -62,17 +63,18 @@ func (r Config) SourceConfigKind() string {
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
// Initializes a BigQuery Google SQL source
|
||||
client, restService, err := initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location)
|
||||
client, restService, tokenSource, err := initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
Client: client,
|
||||
RestService: restService,
|
||||
Location: r.Location,
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
Client: client,
|
||||
RestService: restService,
|
||||
TokenSource: tokenSource,
|
||||
MaxQueryResultRows: 50,
|
||||
}
|
||||
return s, nil
|
||||
|
||||
@@ -82,11 +84,12 @@ var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
// BigQuery Google SQL struct with client
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
Location string `yaml:"location"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
TokenSource oauth2.TokenSource
|
||||
MaxQueryResultRows int
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
@@ -102,38 +105,46 @@ func (s *Source) BigQueryRestService() *bigqueryrestapi.Service {
|
||||
return s.RestService
|
||||
}
|
||||
|
||||
func (s *Source) BigQueryTokenSource() oauth2.TokenSource {
|
||||
return s.TokenSource
|
||||
}
|
||||
|
||||
func (s *Source) GetMaxQueryResultRows() int {
|
||||
return s.MaxQueryResultRows
|
||||
}
|
||||
|
||||
func initBigQueryConnection(
|
||||
ctx context.Context,
|
||||
tracer trace.Tracer,
|
||||
name string,
|
||||
project string,
|
||||
location string,
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, oauth2.TokenSource, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
cred, err := google.FindDefaultCredentials(ctx, bigqueryapi.Scope)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err)
|
||||
return nil, nil, nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err)
|
||||
}
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
// Initialize the high-level BigQuery client
|
||||
client, err := bigqueryapi.NewClient(ctx, project, option.WithUserAgent(userAgent), option.WithCredentials(cred))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
|
||||
return nil, nil, nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
|
||||
}
|
||||
client.Location = location
|
||||
|
||||
// Initialize the low-level BigQuery REST service using the same credentials
|
||||
restService, err := bigqueryrestapi.NewService(ctx, option.WithUserAgent(userAgent), option.WithCredentials(cred))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
|
||||
return nil, nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
|
||||
}
|
||||
|
||||
return client, restService, nil
|
||||
return client, restService, cred.TokenSource, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,523 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package bigqueryconversationalanalytics
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const kind string = "bigquery-conversational-analytics"
|
||||
|
||||
const instructions = `**INSTRUCTIONS - FOLLOW THESE RULES:**
|
||||
1. **CONTENT:** Your answer should present the supporting data and then provide a conclusion based on that data.
|
||||
2. **OUTPUT FORMAT:** Your entire response MUST be in plain text format ONLY.
|
||||
3. **NO CHARTS:** You are STRICTLY FORBIDDEN from generating any charts, graphs, images, or any other form of visualization.`
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
BigQueryClient() *bigqueryapi.Client
|
||||
BigQueryTokenSource() oauth2.TokenSource
|
||||
GetMaxQueryResultRows() int
|
||||
}
|
||||
|
||||
type BQTableReference struct {
|
||||
ProjectID string `json:"projectId"`
|
||||
DatasetID string `json:"datasetId"`
|
||||
TableID string `json:"tableId"`
|
||||
}
|
||||
|
||||
// Structs for building the JSON payload
|
||||
type UserMessage struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
type Message struct {
|
||||
UserMessage UserMessage `json:"userMessage"`
|
||||
}
|
||||
type BQDatasource struct {
|
||||
TableReferences []BQTableReference `json:"tableReferences"`
|
||||
}
|
||||
type DatasourceReferences struct {
|
||||
BQ BQDatasource `json:"bq"`
|
||||
}
|
||||
type ImageOptions struct {
|
||||
NoImage map[string]any `json:"noImage"`
|
||||
}
|
||||
type ChartOptions struct {
|
||||
Image ImageOptions `json:"image"`
|
||||
}
|
||||
type Options struct {
|
||||
Chart ChartOptions `json:"chart"`
|
||||
}
|
||||
type InlineContext struct {
|
||||
DatasourceReferences DatasourceReferences `json:"datasourceReferences"`
|
||||
Options Options `json:"options"`
|
||||
}
|
||||
|
||||
type CAPayload struct {
|
||||
Project string `json:"project"`
|
||||
Messages []Message `json:"messages"`
|
||||
InlineContext InlineContext `json:"inlineContext"`
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
userQueryParameter := tools.NewStringParameter("user_query_with_context", "The user's question, potentially including conversation history and system instructions for context.")
|
||||
tableRefsParameter := tools.NewStringParameter("table_references", `A JSON string of a list of BigQuery tables to use as context. Each object in the list must contain 'projectId', 'datasetId', and 'tableId'. Example: '[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'`)
|
||||
|
||||
parameters := tools.Parameters{userQueryParameter, tableRefsParameter}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.BigQueryClient(),
|
||||
TokenSource: s.BigQueryTokenSource(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
MaxQueryResultRows: s.GetMaxQueryResultRows(),
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Client *bigqueryapi.Client
|
||||
TokenSource oauth2.TokenSource
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
MaxQueryResultRows int
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
// Get credentials for the API call
|
||||
if t.TokenSource == nil {
|
||||
return nil, fmt.Errorf("authentication error: found credentials but they are missing a valid token source")
|
||||
}
|
||||
|
||||
token, err := t.TokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token from credentials: %w", err)
|
||||
}
|
||||
|
||||
// Extract parameters from the map
|
||||
mapParams := params.AsMap()
|
||||
userQuery, _ := mapParams["user_query_with_context"].(string)
|
||||
|
||||
finalQueryText := fmt.Sprintf("%s\n**User Query and Context:**\n%s", instructions, userQuery)
|
||||
|
||||
tableRefsJSON, _ := mapParams["table_references"].(string)
|
||||
var tableRefs []BQTableReference
|
||||
if tableRefsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(tableRefsJSON), &tableRefs); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse 'table_references' JSON string: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Construct URL, headers, and payload
|
||||
projectID := t.Client.Project()
|
||||
location := t.Client.Location
|
||||
if location == "" {
|
||||
location = "us"
|
||||
}
|
||||
caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1alpha/projects/%s/locations/%s:chat", projectID, location)
|
||||
|
||||
headers := map[string]string{
|
||||
"Authorization": fmt.Sprintf("Bearer %s", token.AccessToken),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
payload := CAPayload{
|
||||
Project: fmt.Sprintf("projects/%s", projectID),
|
||||
Messages: []Message{{UserMessage: UserMessage{Text: finalQueryText}}},
|
||||
InlineContext: InlineContext{
|
||||
DatasourceReferences: DatasourceReferences{
|
||||
BQ: BQDatasource{TableReferences: tableRefs},
|
||||
},
|
||||
Options: Options{Chart: ChartOptions{Image: ImageOptions{NoImage: map[string]any{}}}},
|
||||
},
|
||||
}
|
||||
|
||||
// Call the streaming API
|
||||
response, err := getStream(caURL, payload, headers, t.MaxQueryResultRows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// StreamMessage represents a single message object from the streaming API response.
|
||||
type StreamMessage struct {
|
||||
SystemMessage *SystemMessage `json:"systemMessage,omitempty"`
|
||||
Error *ErrorResponse `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// SystemMessage contains different types of system-generated content.
|
||||
type SystemMessage struct {
|
||||
Text *TextResponse `json:"text,omitempty"`
|
||||
Schema *SchemaResponse `json:"schema,omitempty"`
|
||||
Data *DataResponse `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// TextResponse contains textual parts of a message.
|
||||
type TextResponse struct {
|
||||
Parts []string `json:"parts"`
|
||||
}
|
||||
|
||||
// SchemaResponse contains schema-related information.
|
||||
type SchemaResponse struct {
|
||||
Query *SchemaQuery `json:"query,omitempty"`
|
||||
Result *SchemaResult `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
// SchemaQuery holds the question that prompted a schema lookup.
|
||||
type SchemaQuery struct {
|
||||
Question string `json:"question"`
|
||||
}
|
||||
|
||||
// SchemaResult contains the datasources with their schemas.
|
||||
type SchemaResult struct {
|
||||
Datasources []Datasource `json:"datasources"`
|
||||
}
|
||||
|
||||
// Datasource represents a data source with its reference and schema.
|
||||
type Datasource struct {
|
||||
BigQueryTableReference *BQTableReference `json:"bigqueryTableReference,omitempty"`
|
||||
Schema *BQSchema `json:"schema,omitempty"`
|
||||
}
|
||||
|
||||
// BQSchema defines the structure of a BigQuery table.
|
||||
type BQSchema struct {
|
||||
Fields []BQField `json:"fields"`
|
||||
}
|
||||
|
||||
// BQField describes a single column in a BigQuery table.
|
||||
type BQField struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
|
||||
// DataResponse contains data-related information, like queries and results.
|
||||
type DataResponse struct {
|
||||
Query *DataQuery `json:"query,omitempty"`
|
||||
GeneratedSQL string `json:"generatedSql,omitempty"`
|
||||
Result *DataResult `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
// DataQuery holds information about a data retrieval query.
|
||||
type DataQuery struct {
|
||||
Name string `json:"name"`
|
||||
Question string `json:"question"`
|
||||
}
|
||||
|
||||
// DataResult contains the schema and rows of a query result.
|
||||
type DataResult struct {
|
||||
Schema BQSchema `json:"schema"`
|
||||
Data []map[string]any `json:"data"`
|
||||
}
|
||||
|
||||
// ErrorResponse represents an error message from the API.
|
||||
type ErrorResponse struct {
|
||||
Code float64 `json:"code"` // JSON numbers are float64 by default
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func getStream(url string, payload CAPayload, headers map[string]string, maxRows int) (string, error) {
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("API returned non-200 status: %d %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var messages []map[string]any
|
||||
decoder := json.NewDecoder(resp.Body)
|
||||
|
||||
// The response is a JSON array, so we read the opening bracket.
|
||||
if _, err := decoder.Token(); err != nil {
|
||||
if err == io.EOF {
|
||||
return "", nil // Empty response is valid
|
||||
}
|
||||
return "", fmt.Errorf("error reading start of json array: %w", err)
|
||||
}
|
||||
|
||||
for decoder.More() {
|
||||
var msg StreamMessage
|
||||
if err := decoder.Decode(&msg); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return "", fmt.Errorf("error decoding stream message: %w", err)
|
||||
}
|
||||
|
||||
var newMessage map[string]any
|
||||
if msg.SystemMessage != nil {
|
||||
if msg.SystemMessage.Text != nil {
|
||||
newMessage = handleTextResponse(msg.SystemMessage.Text)
|
||||
} else if msg.SystemMessage.Schema != nil {
|
||||
newMessage = handleSchemaResponse(msg.SystemMessage.Schema)
|
||||
} else if msg.SystemMessage.Data != nil {
|
||||
newMessage = handleDataResponse(msg.SystemMessage.Data, maxRows)
|
||||
}
|
||||
} else if msg.Error != nil {
|
||||
newMessage = handleError(msg.Error)
|
||||
}
|
||||
messages = appendMessage(messages, newMessage)
|
||||
}
|
||||
|
||||
var acc strings.Builder
|
||||
for i, msg := range messages {
|
||||
jsonBytes, err := json.MarshalIndent(msg, "", " ")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error marshalling message: %w", err)
|
||||
}
|
||||
acc.Write(jsonBytes)
|
||||
if i < len(messages)-1 {
|
||||
acc.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return acc.String(), nil
|
||||
}
|
||||
|
||||
func formatBqTableRef(tableRef *BQTableReference) string {
|
||||
return fmt.Sprintf("%s.%s.%s", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
|
||||
}
|
||||
|
||||
func formatSchemaAsDict(data *BQSchema) map[string]any {
|
||||
headers := []string{"Column", "Type", "Description", "Mode"}
|
||||
if data == nil {
|
||||
return map[string]any{"headers": headers, "rows": []any{}}
|
||||
}
|
||||
|
||||
var rows [][]any
|
||||
for _, field := range data.Fields {
|
||||
rows = append(rows, []any{field.Name, field.Type, field.Description, field.Mode})
|
||||
}
|
||||
return map[string]any{"headers": headers, "rows": rows}
|
||||
}
|
||||
|
||||
func formatDatasourceAsDict(datasource *Datasource) map[string]any {
|
||||
var sourceName string
|
||||
if datasource.BigQueryTableReference != nil {
|
||||
sourceName = formatBqTableRef(datasource.BigQueryTableReference)
|
||||
}
|
||||
|
||||
var schema map[string]any
|
||||
if datasource.Schema != nil {
|
||||
schema = formatSchemaAsDict(datasource.Schema)
|
||||
}
|
||||
|
||||
return map[string]any{"source_name": sourceName, "schema": schema}
|
||||
}
|
||||
|
||||
func handleTextResponse(resp *TextResponse) map[string]any {
|
||||
return map[string]any{"Answer": strings.Join(resp.Parts, "")}
|
||||
}
|
||||
|
||||
func handleSchemaResponse(resp *SchemaResponse) map[string]any {
|
||||
if resp.Query != nil {
|
||||
return map[string]any{"Question": resp.Query.Question}
|
||||
}
|
||||
if resp.Result != nil {
|
||||
var formattedSources []map[string]any
|
||||
for _, ds := range resp.Result.Datasources {
|
||||
formattedSources = append(formattedSources, formatDatasourceAsDict(&ds))
|
||||
}
|
||||
return map[string]any{"Schema Resolved": formattedSources}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleDataResponse(resp *DataResponse, maxRows int) map[string]any {
|
||||
if resp.Query != nil {
|
||||
return map[string]any{
|
||||
"Retrieval Query": map[string]any{
|
||||
"Query Name": resp.Query.Name,
|
||||
"Question": resp.Query.Question,
|
||||
},
|
||||
}
|
||||
}
|
||||
if resp.GeneratedSQL != "" {
|
||||
return map[string]any{"SQL Generated": resp.GeneratedSQL}
|
||||
}
|
||||
if resp.Result != nil {
|
||||
var headers []string
|
||||
for _, f := range resp.Result.Schema.Fields {
|
||||
headers = append(headers, f.Name)
|
||||
}
|
||||
|
||||
totalRows := len(resp.Result.Data)
|
||||
var compactRows [][]any
|
||||
numRowsToDisplay := totalRows
|
||||
if numRowsToDisplay > maxRows {
|
||||
numRowsToDisplay = maxRows
|
||||
}
|
||||
|
||||
for _, rowVal := range resp.Result.Data[:numRowsToDisplay] {
|
||||
var rowValues []any
|
||||
for _, header := range headers {
|
||||
rowValues = append(rowValues, rowVal[header])
|
||||
}
|
||||
compactRows = append(compactRows, rowValues)
|
||||
}
|
||||
|
||||
summary := fmt.Sprintf("Showing all %d rows.", totalRows)
|
||||
if totalRows > maxRows {
|
||||
summary = fmt.Sprintf("Showing the first %d of %d total rows.", numRowsToDisplay, totalRows)
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"Data Retrieved": map[string]any{
|
||||
"headers": headers,
|
||||
"rows": compactRows,
|
||||
"summary": summary,
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleError(resp *ErrorResponse) map[string]any {
|
||||
return map[string]any{
|
||||
"Error": map[string]any{
|
||||
"Code": int(resp.Code),
|
||||
"Message": resp.Message,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func appendMessage(messages []map[string]any, newMessage map[string]any) []map[string]any {
|
||||
if newMessage == nil {
|
||||
return messages
|
||||
}
|
||||
if len(messages) > 0 {
|
||||
if _, ok := messages[len(messages)-1]["Data Retrieved"]; ok {
|
||||
messages = messages[:len(messages)-1]
|
||||
}
|
||||
}
|
||||
return append(messages, newMessage)
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package bigqueryconversationalanalytics_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryconversationalanalytics"
|
||||
)
|
||||
|
||||
func TestParseFromYamlBigQueryConversationalAnalytics(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: bigquery-conversational-analytics
|
||||
source: my-instance
|
||||
description: some description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": bigqueryconversationalanalytics.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "bigquery-conversational-analytics",
|
||||
Source: "my-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -279,6 +279,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func BQTypeStringFromToolType(toolType string) (string, error) {
|
||||
switch toolType {
|
||||
case "string":
|
||||
|
||||
@@ -158,6 +158,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
datasetInfoWant := "\"Location\":\"US\",\"DefaultTableExpiration\":0,\"Labels\":null,\"Access\":"
|
||||
tableInfoWant := "{\"Name\":\"\",\"Location\":\"US\",\"Description\":\"\",\"Schema\":[{\"Name\":\"id\""
|
||||
ddlWant := `"Query executed successfully and returned no content."`
|
||||
dataInsightsWant := `(?s)Schema Resolved.*Retrieval Query.*SQL Generated.*Answer`
|
||||
// Partial message; the full error message is too long.
|
||||
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"final query validation failed: failed to insert dry run job: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]`
|
||||
createColArray := `["id INT64", "name STRING", "age INT64"]`
|
||||
@@ -182,6 +183,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
runBigQueryGetDatasetInfoToolInvokeTest(t, datasetName, datasetInfoWant)
|
||||
runBigQueryListTableIdsToolInvokeTest(t, datasetName, tableName)
|
||||
runBigQueryGetTableInfoToolInvokeTest(t, datasetName, tableName, tableInfoWant)
|
||||
runBigQueryConversationalAnalyticsInvokeTest(t, datasetName, tableName, dataInsightsWant)
|
||||
}
|
||||
|
||||
// getBigQueryParamToolInfo returns statements and param for my-tool for bigquery kind
|
||||
@@ -420,6 +422,19 @@ func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[str
|
||||
"my-google-auth",
|
||||
},
|
||||
}
|
||||
tools["my-conversational-analytics-tool"] = map[string]any{
|
||||
"kind": "bigquery-conversational-analytics",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to ask BigQuery conversational analytics",
|
||||
}
|
||||
tools["my-auth-conversational-analytics-tool"] = map[string]any{
|
||||
"kind": "bigquery-conversational-analytics",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to ask BigQuery conversational analytics",
|
||||
"authRequired": []string{
|
||||
"my-google-auth",
|
||||
},
|
||||
}
|
||||
config["tools"] = tools
|
||||
return config
|
||||
}
|
||||
@@ -1349,3 +1364,94 @@ func runBigQueryGetTableInfoToolInvokeTest(t *testing.T, datasetName, tableName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQueryConversationalAnalyticsInvokeTest(t *testing.T, datasetName, tableName, dataInsightsWant string) {
|
||||
// Get ID token
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
if err != nil {
|
||||
t.Fatalf("error getting Google ID token: %s", err)
|
||||
}
|
||||
|
||||
tableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, datasetName, tableName)
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke my-conversational-analytics-tool successfully",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-conversational-analytics-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(
|
||||
`{"user_query_with_context": "What are the names in the table?", "table_references": %q}`,
|
||||
tableRefsJSON,
|
||||
))),
|
||||
want: dataInsightsWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-auth-conversational-analytics-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-conversational-analytics-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(
|
||||
`{"user_query_with_context": "What are the names in the table?", "table_references": %q}`,
|
||||
tableRefsJSON,
|
||||
))),
|
||||
want: dataInsightsWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-auth-conversational-analytics-tool without auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-conversational-analytics-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"user_query_with_context": "What are the names in the table?"}`)),
|
||||
isErr: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body: %v", err)
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
wantPattern := regexp.MustCompile(tc.want)
|
||||
if !wantPattern.MatchString(got) {
|
||||
t.Fatalf("response did not match the expected pattern.\nFull response:\n%s", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user