mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-13 17:38:10 -05:00
Compare commits
1 Commits
guide
...
feat/bigqu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00b5a9c18e |
@@ -52,7 +52,7 @@ var _ sources.SourceConfig = Config{}
|
||||
|
||||
type BigqueryClientCreator func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||
|
||||
type BigQuerySessionProvider func(ctx context.Context) (*Session, error)
|
||||
type BigQuerySessionProvider func(ctx context.Context, toolName string) (*Session, error)
|
||||
|
||||
type DataplexClientCreator func(tokenString string) (*dataplexapi.CatalogClient, error)
|
||||
|
||||
@@ -287,7 +287,7 @@ func (s *Source) BigQuerySession() BigQuerySessionProvider {
|
||||
}
|
||||
|
||||
func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
||||
return func(ctx context.Context) (*Session, error) {
|
||||
return func(ctx context.Context, toolName string) (*Session, error) {
|
||||
if s.WriteMode != WriteModeProtected {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -300,6 +300,8 @@ func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
||||
return nil, fmt.Errorf("failed to get logger from context: %w", err)
|
||||
}
|
||||
|
||||
labels := map[string]string{"genai-toolbox-tool": toolName}
|
||||
|
||||
if s.Session != nil {
|
||||
// Absolute 7-day lifetime check.
|
||||
const sessionMaxLifetime = 7 * 24 * time.Hour
|
||||
@@ -310,6 +312,7 @@ func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
||||
} else {
|
||||
job := &bigqueryrestapi.Job{
|
||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||
Labels: labels,
|
||||
DryRun: true,
|
||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||
Query: "SELECT 1",
|
||||
@@ -337,6 +340,7 @@ func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
||||
Location: s.Location,
|
||||
},
|
||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||
Labels: labels,
|
||||
DryRun: true,
|
||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||
Query: "SELECT 1",
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -12,186 +12,97 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package bigquery_test
|
||||
package bigquery
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"google.golang.org/api/option"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
)
|
||||
|
||||
func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "",
|
||||
WriteMode: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "all fields specified",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: asia
|
||||
writeMode: blocked
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "asia",
|
||||
WriteMode: "blocked",
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "use client auth example",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
useClientOAuth: true
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with allowed datasets example",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
allowedDatasets:
|
||||
- my_dataset
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
AllowedDatasets: []string{"my_dataset"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with service account impersonation example",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
impersonateServiceAccount: service-account@my-project.iam.gserviceaccount.com
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
ImpersonateServiceAccount: "service-account@my-project.iam.gserviceaccount.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
func TestNewBigQuerySessionProvider(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger, err := log.NewStdLogger(&buf, &buf, "DEBUG")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create logger: %v", err)
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
ctx := util.WithLogger(context.Background(), logger)
|
||||
projectID := "test-project"
|
||||
location := "us"
|
||||
toolName := "test-tool"
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/projects/test-project/jobs" {
|
||||
t.Errorf("expected to request '/projects/test-project/jobs', got: %s", r.URL.Path)
|
||||
}
|
||||
var job bigqueryrestapi.Job
|
||||
if err := json.NewDecoder(r.Body).Decode(&job); err != nil {
|
||||
t.Fatalf("failed to decode request body: %v", err)
|
||||
}
|
||||
|
||||
expectedLabels := map[string]string{"genai-toolbox-tool": toolName}
|
||||
if !reflect.DeepEqual(job.Configuration.Labels, expectedLabels) {
|
||||
t.Errorf("expected labels %v, got %v", expectedLabels, job.Configuration.Labels)
|
||||
}
|
||||
|
||||
if !job.Configuration.Query.CreateSession {
|
||||
t.Errorf("expected CreateSession to be true")
|
||||
}
|
||||
|
||||
// Send back a dummy response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(&bigqueryrestapi.Job{
|
||||
JobReference: &bigqueryrestapi.JobReference{
|
||||
ProjectId: projectID,
|
||||
JobId: "job_123",
|
||||
Location: location,
|
||||
},
|
||||
Status: &bigqueryrestapi.JobStatus{
|
||||
State: "DONE",
|
||||
},
|
||||
Statistics: &bigqueryrestapi.JobStatistics{
|
||||
SessionInfo: &bigqueryrestapi.SessionInfo{
|
||||
SessionId: "session_123",
|
||||
},
|
||||
},
|
||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||
DestinationTable: &bigqueryrestapi.TableReference{
|
||||
ProjectId: projectID,
|
||||
DatasetId: "dataset_123",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
restService, err := bigqueryrestapi.NewService(ctx, option.WithEndpoint(server.URL), option.WithHTTPClient(http.DefaultClient))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test service: %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
s := &Source{
|
||||
Config: Config{
|
||||
Project: projectID,
|
||||
Location: location,
|
||||
WriteMode: WriteModeProtected,
|
||||
},
|
||||
RestService: restService,
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
foo: bar
|
||||
`,
|
||||
err: "unable to parse source \"my-instance\" as \"bigquery\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: bigquery\n 3 | location: us\n 4 | project: my-project",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
location: us
|
||||
`,
|
||||
err: "unable to parse source \"my-instance\" as \"bigquery\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
sessionProvider := s.newBigQuerySessionProvider()
|
||||
_, err = sessionProvider(ctx, toolName)
|
||||
if err != nil {
|
||||
t.Fatalf("sessionProvider failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,7 +231,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
||||
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := t.SessionProvider(ctx, kind)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -240,7 +240,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
||||
{Key: "session_id", Value: session.ID},
|
||||
}
|
||||
}
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps)
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps, kind)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
}
|
||||
@@ -289,10 +289,11 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
||||
)
|
||||
|
||||
createModelQuery := bqClient.Query(createModelSQL)
|
||||
createModelQuery.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||
|
||||
// Get session from provider if in protected mode.
|
||||
// Otherwise, a new session will be created by the first query.
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := t.SessionProvider(ctx, kind)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -332,6 +333,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
||||
getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)
|
||||
|
||||
getInsightsQuery := bqClient.Query(getInsightsSQL)
|
||||
getInsightsQuery.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||
getInsightsQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}}
|
||||
|
||||
job, err := getInsightsQuery.Run(ctx)
|
||||
|
||||
@@ -26,7 +26,7 @@ import (
|
||||
)
|
||||
|
||||
// DryRunQuery performs a dry run of the SQL query to validate it and get metadata.
|
||||
func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, projectID string, location string, sql string, params []*bigqueryrestapi.QueryParameter, connProps []*bigqueryapi.ConnectionProperty) (*bigqueryrestapi.Job, error) {
|
||||
func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, projectID string, location string, sql string, params []*bigqueryrestapi.QueryParameter, connProps []*bigqueryapi.ConnectionProperty, toolName string) (*bigqueryrestapi.Job, error) {
|
||||
useLegacySql := false
|
||||
|
||||
restConnProps := make([]*bigqueryrestapi.ConnectionProperty, len(connProps))
|
||||
@@ -40,6 +40,7 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj
|
||||
Location: location,
|
||||
},
|
||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||
Labels: getLabels(toolName),
|
||||
DryRun: true,
|
||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||
Query: sql,
|
||||
@@ -57,6 +58,10 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj
|
||||
return insertResponse, nil
|
||||
}
|
||||
|
||||
func getLabels(toolName string) map[string]string {
|
||||
return map[string]string{"genai-toolbox-tool": toolName}
|
||||
}
|
||||
|
||||
// BQTypeStringFromToolType converts a tool parameter type string to a BigQuery standard SQL type string.
|
||||
func BQTypeStringFromToolType(toolType string) (string, error) {
|
||||
switch toolType {
|
||||
|
||||
90
internal/tools/bigquery/bigquerycommon/util_test.go
Normal file
90
internal/tools/bigquery/bigquerycommon/util_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package bigquerycommon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/api/option"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
)
|
||||
|
||||
func TestGetLabels(t *testing.T) {
|
||||
toolName := "test-tool"
|
||||
expected := map[string]string{"genai-toolbox-tool": toolName}
|
||||
actual := getLabels(toolName)
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("getLabels() = %v, want %v", actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDryRunQuery(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
projectID := "test-project"
|
||||
location := "us"
|
||||
sql := "SELECT 1"
|
||||
toolName := "test-tool"
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/projects/test-project/jobs" {
|
||||
t.Errorf("expected to request '/projects/test-project/jobs', got: %s", r.URL.Path)
|
||||
}
|
||||
var job bigqueryrestapi.Job
|
||||
if err := json.NewDecoder(r.Body).Decode(&job); err != nil {
|
||||
t.Fatalf("failed to decode request body: %v", err)
|
||||
}
|
||||
|
||||
expectedLabels := map[string]string{"genai-toolbox-tool": toolName}
|
||||
if !reflect.DeepEqual(job.Configuration.Labels, expectedLabels) {
|
||||
t.Errorf("expected labels %v, got %v", expectedLabels, job.Configuration.Labels)
|
||||
}
|
||||
|
||||
if !job.Configuration.DryRun {
|
||||
t.Errorf("expected DryRun to be true")
|
||||
}
|
||||
|
||||
// Send back a dummy response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(&bigqueryrestapi.Job{
|
||||
JobReference: &bigqueryrestapi.JobReference{
|
||||
ProjectId: projectID,
|
||||
JobId: "job_123",
|
||||
Location: location,
|
||||
},
|
||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||
DryRun: true,
|
||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||
Query: sql,
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
restService, err := bigqueryrestapi.NewService(ctx, option.WithEndpoint(server.URL), option.WithHTTPClient(http.DefaultClient))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test service: %v", err)
|
||||
}
|
||||
|
||||
_, err = DryRunQuery(ctx, restService, projectID, location, sql, nil, nil, toolName)
|
||||
if err != nil {
|
||||
t.Fatalf("DryRunQuery failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -99,10 +99,11 @@ type InlineContext struct {
|
||||
}
|
||||
|
||||
type CAPayload struct {
|
||||
Project string `json:"project"`
|
||||
Messages []Message `json:"messages"`
|
||||
InlineContext InlineContext `json:"inlineContext"`
|
||||
ClientIdEnum string `json:"clientIdEnum"`
|
||||
Project string `json:"project"`
|
||||
Messages []Message `json:"messages"`
|
||||
InlineContext InlineContext `json:"inlineContext"`
|
||||
ClientIdEnum string `json:"clientIdEnum"`
|
||||
JobLabels map[string]string `json:"jobLabels,omitempty"`
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
@@ -276,6 +277,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
||||
Options: Options{Chart: ChartOptions{Image: ImageOptions{NoImage: map[string]any{}}}},
|
||||
},
|
||||
ClientIdEnum: "GENAI_TOOLBOX",
|
||||
JobLabels: map[string]string{"genai-toolbox-tool": kind},
|
||||
}
|
||||
|
||||
// Call the streaming API
|
||||
|
||||
@@ -205,7 +205,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
var session *bigqueryds.Session
|
||||
if t.WriteMode == bigqueryds.WriteModeProtected {
|
||||
session, err = t.SessionProvider(ctx)
|
||||
session, err = t.SessionProvider(ctx, kind)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
|
||||
}
|
||||
@@ -214,7 +214,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
||||
}
|
||||
}
|
||||
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps)
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps, kind)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
}
|
||||
@@ -303,6 +303,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
||||
|
||||
query := bqClient.Query(sql)
|
||||
query.Location = bqClient.Location
|
||||
query.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||
|
||||
query.ConnectionProperties = connProps
|
||||
|
||||
|
||||
@@ -1,73 +1,111 @@
|
||||
// Copyright 2025 Google LLC
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package bigqueryexecutesql_test
|
||||
package bigqueryexecutesql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"cloud.google.com/go/bigquery"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryexecutesql"
|
||||
)
|
||||
|
||||
func TestParseFromYamlBigQueryExecuteSql(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
func TestParseFromYaml(t *testing.T) {
|
||||
t.Parallel()
|
||||
const (
|
||||
basicYAML = `
|
||||
name: bigquery-execute-sql-tool
|
||||
kind: bigquery-execute-sql
|
||||
source: bq
|
||||
description: test
|
||||
authRequired:
|
||||
- gcp
|
||||
`
|
||||
)
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
isError bool
|
||||
want *Config
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: bigquery-execute-sql
|
||||
source: my-instance
|
||||
description: some description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": bigqueryexecutesql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "bigquery-execute-sql",
|
||||
Source: "my-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
name: "basic example",
|
||||
input: basicYAML,
|
||||
want: &Config{
|
||||
Name: "bigquery-execute-sql-tool",
|
||||
Kind: "bigquery-execute-sql",
|
||||
Source: "bq",
|
||||
Description: "test",
|
||||
AuthRequired: []string{"gcp"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var got Config
|
||||
err := yaml.Unmarshal([]byte(tc.input), &got)
|
||||
if tc.isError {
|
||||
if err == nil {
|
||||
t.Errorf("yaml.Unmarshal got nil error, want error")
|
||||
}
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
if err != nil {
|
||||
t.Fatalf("yaml.Unmarshal got unexpected error: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, &got); diff != "" {
|
||||
t.Errorf("yaml.Unmarshal() returned diff (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type mockQuery struct {
|
||||
labels map[string]string
|
||||
}
|
||||
|
||||
func (q *mockQuery) Run(ctx context.Context) (*bigquery.Job, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (q *mockQuery) Read(ctx context.Context) (*bigquery.RowIterator, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type mockBigQueryClient struct {
|
||||
*bigquery.Client
|
||||
t *testing.T
|
||||
projectID string
|
||||
location string
|
||||
}
|
||||
|
||||
func (c *mockBigQueryClient) Query(sql string) *bigquery.Query {
|
||||
q := &bigquery.Query{}
|
||||
return q
|
||||
}
|
||||
|
||||
func (c *mockBigQueryClient) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockBigQueryClient) Project() string {
|
||||
return c.projectID
|
||||
}
|
||||
|
||||
func (c *mockBigQueryClient) Location() string {
|
||||
return c.location
|
||||
}
|
||||
|
||||
@@ -209,7 +209,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
||||
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := t.SessionProvider(ctx, kind)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -218,7 +218,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
||||
{Key: "session_id", Value: session.ID},
|
||||
}
|
||||
}
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps)
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps, kind)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
}
|
||||
@@ -279,7 +279,8 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
||||
// JobStatistics.QueryStatistics.StatementType
|
||||
query := bqClient.Query(sql)
|
||||
query.Location = bqClient.Location
|
||||
session, err := t.SessionProvider(ctx)
|
||||
query.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||
session, err := t.SessionProvider(ctx, kind)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
|
||||
@@ -230,10 +230,11 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
||||
query := bqClient.Query(newStatement)
|
||||
query.Parameters = highLevelParams
|
||||
query.Location = bqClient.Location
|
||||
query.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||
|
||||
connProps := []*bigqueryapi.ConnectionProperty{}
|
||||
if t.SessionProvider != nil {
|
||||
session, err := t.SessionProvider(ctx)
|
||||
session, err := t.SessionProvider(ctx, kind)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
@@ -243,7 +244,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
||||
}
|
||||
}
|
||||
query.ConnectionProperties = connProps
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), query.Location, newStatement, lowLevelParams, connProps)
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), query.Location, newStatement, lowLevelParams, connProps, kind)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user