mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-11 23:55:07 -05:00
Compare commits
1 Commits
adk-python
...
feat/bigqu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00b5a9c18e |
@@ -52,7 +52,7 @@ var _ sources.SourceConfig = Config{}
|
|||||||
|
|
||||||
type BigqueryClientCreator func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
type BigqueryClientCreator func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||||
|
|
||||||
type BigQuerySessionProvider func(ctx context.Context) (*Session, error)
|
type BigQuerySessionProvider func(ctx context.Context, toolName string) (*Session, error)
|
||||||
|
|
||||||
type DataplexClientCreator func(tokenString string) (*dataplexapi.CatalogClient, error)
|
type DataplexClientCreator func(tokenString string) (*dataplexapi.CatalogClient, error)
|
||||||
|
|
||||||
@@ -287,7 +287,7 @@ func (s *Source) BigQuerySession() BigQuerySessionProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
||||||
return func(ctx context.Context) (*Session, error) {
|
return func(ctx context.Context, toolName string) (*Session, error) {
|
||||||
if s.WriteMode != WriteModeProtected {
|
if s.WriteMode != WriteModeProtected {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -300,6 +300,8 @@ func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
|||||||
return nil, fmt.Errorf("failed to get logger from context: %w", err)
|
return nil, fmt.Errorf("failed to get logger from context: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
labels := map[string]string{"genai-toolbox-tool": toolName}
|
||||||
|
|
||||||
if s.Session != nil {
|
if s.Session != nil {
|
||||||
// Absolute 7-day lifetime check.
|
// Absolute 7-day lifetime check.
|
||||||
const sessionMaxLifetime = 7 * 24 * time.Hour
|
const sessionMaxLifetime = 7 * 24 * time.Hour
|
||||||
@@ -310,6 +312,7 @@ func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
|||||||
} else {
|
} else {
|
||||||
job := &bigqueryrestapi.Job{
|
job := &bigqueryrestapi.Job{
|
||||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||||
|
Labels: labels,
|
||||||
DryRun: true,
|
DryRun: true,
|
||||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||||
Query: "SELECT 1",
|
Query: "SELECT 1",
|
||||||
@@ -337,6 +340,7 @@ func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
|||||||
Location: s.Location,
|
Location: s.Location,
|
||||||
},
|
},
|
||||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||||
|
Labels: labels,
|
||||||
DryRun: true,
|
DryRun: true,
|
||||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||||
Query: "SELECT 1",
|
Query: "SELECT 1",
|
||||||
|
|||||||
@@ -12,186 +12,97 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package bigquery_test
|
package bigquery
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
"github.com/googleapis/genai-toolbox/internal/log"
|
||||||
"github.com/google/go-cmp/cmp"
|
"google.golang.org/api/option"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseFromYamlBigQuery(t *testing.T) {
|
func TestNewBigQuerySessionProvider(t *testing.T) {
|
||||||
tcs := []struct {
|
var buf bytes.Buffer
|
||||||
desc string
|
logger, err := log.NewStdLogger(&buf, &buf, "DEBUG")
|
||||||
in string
|
|
||||||
want server.SourceConfigs
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "basic example",
|
|
||||||
in: `
|
|
||||||
sources:
|
|
||||||
my-instance:
|
|
||||||
kind: bigquery
|
|
||||||
project: my-project
|
|
||||||
`,
|
|
||||||
want: server.SourceConfigs{
|
|
||||||
"my-instance": bigquery.Config{
|
|
||||||
Name: "my-instance",
|
|
||||||
Kind: bigquery.SourceKind,
|
|
||||||
Project: "my-project",
|
|
||||||
Location: "",
|
|
||||||
WriteMode: "",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "all fields specified",
|
|
||||||
in: `
|
|
||||||
sources:
|
|
||||||
my-instance:
|
|
||||||
kind: bigquery
|
|
||||||
project: my-project
|
|
||||||
location: asia
|
|
||||||
writeMode: blocked
|
|
||||||
`,
|
|
||||||
want: server.SourceConfigs{
|
|
||||||
"my-instance": bigquery.Config{
|
|
||||||
Name: "my-instance",
|
|
||||||
Kind: bigquery.SourceKind,
|
|
||||||
Project: "my-project",
|
|
||||||
Location: "asia",
|
|
||||||
WriteMode: "blocked",
|
|
||||||
UseClientOAuth: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "use client auth example",
|
|
||||||
in: `
|
|
||||||
sources:
|
|
||||||
my-instance:
|
|
||||||
kind: bigquery
|
|
||||||
project: my-project
|
|
||||||
location: us
|
|
||||||
useClientOAuth: true
|
|
||||||
`,
|
|
||||||
want: server.SourceConfigs{
|
|
||||||
"my-instance": bigquery.Config{
|
|
||||||
Name: "my-instance",
|
|
||||||
Kind: bigquery.SourceKind,
|
|
||||||
Project: "my-project",
|
|
||||||
Location: "us",
|
|
||||||
UseClientOAuth: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "with allowed datasets example",
|
|
||||||
in: `
|
|
||||||
sources:
|
|
||||||
my-instance:
|
|
||||||
kind: bigquery
|
|
||||||
project: my-project
|
|
||||||
location: us
|
|
||||||
allowedDatasets:
|
|
||||||
- my_dataset
|
|
||||||
`,
|
|
||||||
want: server.SourceConfigs{
|
|
||||||
"my-instance": bigquery.Config{
|
|
||||||
Name: "my-instance",
|
|
||||||
Kind: bigquery.SourceKind,
|
|
||||||
Project: "my-project",
|
|
||||||
Location: "us",
|
|
||||||
AllowedDatasets: []string{"my_dataset"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "with service account impersonation example",
|
|
||||||
in: `
|
|
||||||
sources:
|
|
||||||
my-instance:
|
|
||||||
kind: bigquery
|
|
||||||
project: my-project
|
|
||||||
location: us
|
|
||||||
impersonateServiceAccount: service-account@my-project.iam.gserviceaccount.com
|
|
||||||
`,
|
|
||||||
want: server.SourceConfigs{
|
|
||||||
"my-instance": bigquery.Config{
|
|
||||||
Name: "my-instance",
|
|
||||||
Kind: bigquery.SourceKind,
|
|
||||||
Project: "my-project",
|
|
||||||
Location: "us",
|
|
||||||
ImpersonateServiceAccount: "service-account@my-project.iam.gserviceaccount.com",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tc := range tcs {
|
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
|
||||||
got := struct {
|
|
||||||
Sources server.SourceConfigs `yaml:"sources"`
|
|
||||||
}{}
|
|
||||||
// Parse contents
|
|
||||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to unmarshal: %s", err)
|
t.Fatalf("failed to create logger: %v", err)
|
||||||
}
|
}
|
||||||
if !cmp.Equal(tc.want, got.Sources) {
|
ctx := util.WithLogger(context.Background(), logger)
|
||||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
projectID := "test-project"
|
||||||
|
location := "us"
|
||||||
|
toolName := "test-tool"
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/projects/test-project/jobs" {
|
||||||
|
t.Errorf("expected to request '/projects/test-project/jobs', got: %s", r.URL.Path)
|
||||||
}
|
}
|
||||||
})
|
var job bigqueryrestapi.Job
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&job); err != nil {
|
||||||
|
t.Fatalf("failed to decode request body: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expectedLabels := map[string]string{"genai-toolbox-tool": toolName}
|
||||||
|
if !reflect.DeepEqual(job.Configuration.Labels, expectedLabels) {
|
||||||
|
t.Errorf("expected labels %v, got %v", expectedLabels, job.Configuration.Labels)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFailParseFromYaml(t *testing.T) {
|
if !job.Configuration.Query.CreateSession {
|
||||||
tcs := []struct {
|
t.Errorf("expected CreateSession to be true")
|
||||||
desc string
|
}
|
||||||
in string
|
|
||||||
err string
|
// Send back a dummy response
|
||||||
}{
|
w.Header().Set("Content-Type", "application/json")
|
||||||
{
|
json.NewEncoder(w).Encode(&bigqueryrestapi.Job{
|
||||||
desc: "extra field",
|
JobReference: &bigqueryrestapi.JobReference{
|
||||||
in: `
|
ProjectId: projectID,
|
||||||
sources:
|
JobId: "job_123",
|
||||||
my-instance:
|
Location: location,
|
||||||
kind: bigquery
|
},
|
||||||
project: my-project
|
Status: &bigqueryrestapi.JobStatus{
|
||||||
location: us
|
State: "DONE",
|
||||||
foo: bar
|
},
|
||||||
`,
|
Statistics: &bigqueryrestapi.JobStatistics{
|
||||||
err: "unable to parse source \"my-instance\" as \"bigquery\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: bigquery\n 3 | location: us\n 4 | project: my-project",
|
SessionInfo: &bigqueryrestapi.SessionInfo{
|
||||||
|
SessionId: "session_123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||||
|
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||||
|
DestinationTable: &bigqueryrestapi.TableReference{
|
||||||
|
ProjectId: projectID,
|
||||||
|
DatasetId: "dataset_123",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
desc: "missing required field",
|
|
||||||
in: `
|
|
||||||
sources:
|
|
||||||
my-instance:
|
|
||||||
kind: bigquery
|
|
||||||
location: us
|
|
||||||
`,
|
|
||||||
err: "unable to parse source \"my-instance\" as \"bigquery\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
|
||||||
},
|
},
|
||||||
}
|
|
||||||
for _, tc := range tcs {
|
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
|
||||||
got := struct {
|
|
||||||
Sources server.SourceConfigs `yaml:"sources"`
|
|
||||||
}{}
|
|
||||||
// Parse contents
|
|
||||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expect parsing to fail")
|
|
||||||
}
|
|
||||||
errStr := err.Error()
|
|
||||||
if errStr != tc.err {
|
|
||||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
restService, err := bigqueryrestapi.NewService(ctx, option.WithEndpoint(server.URL), option.WithHTTPClient(http.DefaultClient))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create test service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Source{
|
||||||
|
Config: Config{
|
||||||
|
Project: projectID,
|
||||||
|
Location: location,
|
||||||
|
WriteMode: WriteModeProtected,
|
||||||
|
},
|
||||||
|
RestService: restService,
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionProvider := s.newBigQuerySessionProvider()
|
||||||
|
_, err = sessionProvider(ctx, toolName)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sessionProvider failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
|
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
|
||||||
if len(t.AllowedDatasets) > 0 {
|
if len(t.AllowedDatasets) > 0 {
|
||||||
var connProps []*bigqueryapi.ConnectionProperty
|
var connProps []*bigqueryapi.ConnectionProperty
|
||||||
session, err := t.SessionProvider(ctx)
|
session, err := t.SessionProvider(ctx, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||||
}
|
}
|
||||||
@@ -240,7 +240,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
{Key: "session_id", Value: session.ID},
|
{Key: "session_id", Value: session.ID},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -289,10 +289,11 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
)
|
)
|
||||||
|
|
||||||
createModelQuery := bqClient.Query(createModelSQL)
|
createModelQuery := bqClient.Query(createModelSQL)
|
||||||
|
createModelQuery.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||||
|
|
||||||
// Get session from provider if in protected mode.
|
// Get session from provider if in protected mode.
|
||||||
// Otherwise, a new session will be created by the first query.
|
// Otherwise, a new session will be created by the first query.
|
||||||
session, err := t.SessionProvider(ctx)
|
session, err := t.SessionProvider(ctx, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||||
}
|
}
|
||||||
@@ -332,6 +333,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)
|
getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)
|
||||||
|
|
||||||
getInsightsQuery := bqClient.Query(getInsightsSQL)
|
getInsightsQuery := bqClient.Query(getInsightsSQL)
|
||||||
|
getInsightsQuery.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||||
getInsightsQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}}
|
getInsightsQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}}
|
||||||
|
|
||||||
job, err := getInsightsQuery.Run(ctx)
|
job, err := getInsightsQuery.Run(ctx)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// DryRunQuery performs a dry run of the SQL query to validate it and get metadata.
|
// DryRunQuery performs a dry run of the SQL query to validate it and get metadata.
|
||||||
func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, projectID string, location string, sql string, params []*bigqueryrestapi.QueryParameter, connProps []*bigqueryapi.ConnectionProperty) (*bigqueryrestapi.Job, error) {
|
func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, projectID string, location string, sql string, params []*bigqueryrestapi.QueryParameter, connProps []*bigqueryapi.ConnectionProperty, toolName string) (*bigqueryrestapi.Job, error) {
|
||||||
useLegacySql := false
|
useLegacySql := false
|
||||||
|
|
||||||
restConnProps := make([]*bigqueryrestapi.ConnectionProperty, len(connProps))
|
restConnProps := make([]*bigqueryrestapi.ConnectionProperty, len(connProps))
|
||||||
@@ -40,6 +40,7 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj
|
|||||||
Location: location,
|
Location: location,
|
||||||
},
|
},
|
||||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||||
|
Labels: getLabels(toolName),
|
||||||
DryRun: true,
|
DryRun: true,
|
||||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||||
Query: sql,
|
Query: sql,
|
||||||
@@ -57,6 +58,10 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj
|
|||||||
return insertResponse, nil
|
return insertResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getLabels(toolName string) map[string]string {
|
||||||
|
return map[string]string{"genai-toolbox-tool": toolName}
|
||||||
|
}
|
||||||
|
|
||||||
// BQTypeStringFromToolType converts a tool parameter type string to a BigQuery standard SQL type string.
|
// BQTypeStringFromToolType converts a tool parameter type string to a BigQuery standard SQL type string.
|
||||||
func BQTypeStringFromToolType(toolType string) (string, error) {
|
func BQTypeStringFromToolType(toolType string) (string, error) {
|
||||||
switch toolType {
|
switch toolType {
|
||||||
|
|||||||
90
internal/tools/bigquery/bigquerycommon/util_test.go
Normal file
90
internal/tools/bigquery/bigquerycommon/util_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
// Copyright 2025 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package bigquerycommon
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"google.golang.org/api/option"
|
||||||
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetLabels(t *testing.T) {
|
||||||
|
toolName := "test-tool"
|
||||||
|
expected := map[string]string{"genai-toolbox-tool": toolName}
|
||||||
|
actual := getLabels(toolName)
|
||||||
|
if !reflect.DeepEqual(expected, actual) {
|
||||||
|
t.Errorf("getLabels() = %v, want %v", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDryRunQuery(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
projectID := "test-project"
|
||||||
|
location := "us"
|
||||||
|
sql := "SELECT 1"
|
||||||
|
toolName := "test-tool"
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/projects/test-project/jobs" {
|
||||||
|
t.Errorf("expected to request '/projects/test-project/jobs', got: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
var job bigqueryrestapi.Job
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&job); err != nil {
|
||||||
|
t.Fatalf("failed to decode request body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedLabels := map[string]string{"genai-toolbox-tool": toolName}
|
||||||
|
if !reflect.DeepEqual(job.Configuration.Labels, expectedLabels) {
|
||||||
|
t.Errorf("expected labels %v, got %v", expectedLabels, job.Configuration.Labels)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !job.Configuration.DryRun {
|
||||||
|
t.Errorf("expected DryRun to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send back a dummy response
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(&bigqueryrestapi.Job{
|
||||||
|
JobReference: &bigqueryrestapi.JobReference{
|
||||||
|
ProjectId: projectID,
|
||||||
|
JobId: "job_123",
|
||||||
|
Location: location,
|
||||||
|
},
|
||||||
|
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||||
|
DryRun: true,
|
||||||
|
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||||
|
Query: sql,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
restService, err := bigqueryrestapi.NewService(ctx, option.WithEndpoint(server.URL), option.WithHTTPClient(http.DefaultClient))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create test service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = DryRunQuery(ctx, restService, projectID, location, sql, nil, nil, toolName)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DryRunQuery failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -103,6 +103,7 @@ type CAPayload struct {
|
|||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
InlineContext InlineContext `json:"inlineContext"`
|
InlineContext InlineContext `json:"inlineContext"`
|
||||||
ClientIdEnum string `json:"clientIdEnum"`
|
ClientIdEnum string `json:"clientIdEnum"`
|
||||||
|
JobLabels map[string]string `json:"jobLabels,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate compatible sources are still compatible
|
// validate compatible sources are still compatible
|
||||||
@@ -276,6 +277,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
Options: Options{Chart: ChartOptions{Image: ImageOptions{NoImage: map[string]any{}}}},
|
Options: Options{Chart: ChartOptions{Image: ImageOptions{NoImage: map[string]any{}}}},
|
||||||
},
|
},
|
||||||
ClientIdEnum: "GENAI_TOOLBOX",
|
ClientIdEnum: "GENAI_TOOLBOX",
|
||||||
|
JobLabels: map[string]string{"genai-toolbox-tool": kind},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call the streaming API
|
// Call the streaming API
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
var connProps []*bigqueryapi.ConnectionProperty
|
var connProps []*bigqueryapi.ConnectionProperty
|
||||||
var session *bigqueryds.Session
|
var session *bigqueryds.Session
|
||||||
if t.WriteMode == bigqueryds.WriteModeProtected {
|
if t.WriteMode == bigqueryds.WriteModeProtected {
|
||||||
session, err = t.SessionProvider(ctx)
|
session, err = t.SessionProvider(ctx, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
|
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
|
||||||
}
|
}
|
||||||
@@ -214,7 +214,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -303,6 +303,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
|
|
||||||
query := bqClient.Query(sql)
|
query := bqClient.Query(sql)
|
||||||
query.Location = bqClient.Location
|
query.Location = bqClient.Location
|
||||||
|
query.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||||
|
|
||||||
query.ConnectionProperties = connProps
|
query.ConnectionProperties = connProps
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Copyright 2025 Google LLC
|
// Copyright 2024 Google LLC
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@@ -11,63 +11,101 @@
|
|||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
package bigqueryexecutesql
|
||||||
package bigqueryexecutesql_test
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
"cloud.google.com/go/bigquery"
|
||||||
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryexecutesql"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseFromYamlBigQueryExecuteSql(t *testing.T) {
|
func TestParseFromYaml(t *testing.T) {
|
||||||
ctx, err := testutils.ContextWithNewLogger()
|
t.Parallel()
|
||||||
if err != nil {
|
const (
|
||||||
t.Fatalf("unexpected error: %s", err)
|
basicYAML = `
|
||||||
}
|
name: bigquery-execute-sql-tool
|
||||||
tcs := []struct {
|
kind: bigquery-execute-sql
|
||||||
desc string
|
source: bq
|
||||||
in string
|
description: test
|
||||||
want server.ToolConfigs
|
authRequired:
|
||||||
|
- gcp
|
||||||
|
`
|
||||||
|
)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
isError bool
|
||||||
|
want *Config
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
desc: "basic example",
|
name: "basic example",
|
||||||
in: `
|
input: basicYAML,
|
||||||
tools:
|
want: &Config{
|
||||||
example_tool:
|
Name: "bigquery-execute-sql-tool",
|
||||||
kind: bigquery-execute-sql
|
|
||||||
source: my-instance
|
|
||||||
description: some description
|
|
||||||
`,
|
|
||||||
want: server.ToolConfigs{
|
|
||||||
"example_tool": bigqueryexecutesql.Config{
|
|
||||||
Name: "example_tool",
|
|
||||||
Kind: "bigquery-execute-sql",
|
Kind: "bigquery-execute-sql",
|
||||||
Source: "my-instance",
|
Source: "bq",
|
||||||
Description: "some description",
|
Description: "test",
|
||||||
AuthRequired: []string{},
|
AuthRequired: []string{"gcp"},
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range tcs {
|
for _, tc := range tests {
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
tc := tc
|
||||||
got := struct {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
Tools server.ToolConfigs `yaml:"tools"`
|
t.Parallel()
|
||||||
}{}
|
var got Config
|
||||||
// Parse contents
|
err := yaml.Unmarshal([]byte(tc.input), &got)
|
||||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
if tc.isError {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("yaml.Unmarshal got nil error, want error")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to unmarshal: %s", err)
|
t.Fatalf("yaml.Unmarshal got unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
if diff := cmp.Diff(tc.want, &got); diff != "" {
|
||||||
t.Fatalf("incorrect parse: diff %v", diff)
|
t.Errorf("yaml.Unmarshal() returned diff (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockQuery struct {
|
||||||
|
labels map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *mockQuery) Run(ctx context.Context) (*bigquery.Job, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *mockQuery) Read(ctx context.Context) (*bigquery.RowIterator, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockBigQueryClient struct {
|
||||||
|
*bigquery.Client
|
||||||
|
t *testing.T
|
||||||
|
projectID string
|
||||||
|
location string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mockBigQueryClient) Query(sql string) *bigquery.Query {
|
||||||
|
q := &bigquery.Query{}
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mockBigQueryClient) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mockBigQueryClient) Project() string {
|
||||||
|
return c.projectID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mockBigQueryClient) Location() string {
|
||||||
|
return c.location
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
|
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
|
||||||
if len(t.AllowedDatasets) > 0 {
|
if len(t.AllowedDatasets) > 0 {
|
||||||
var connProps []*bigqueryapi.ConnectionProperty
|
var connProps []*bigqueryapi.ConnectionProperty
|
||||||
session, err := t.SessionProvider(ctx)
|
session, err := t.SessionProvider(ctx, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||||
}
|
}
|
||||||
@@ -218,7 +218,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
{Key: "session_id", Value: session.ID},
|
{Key: "session_id", Value: session.ID},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -279,7 +279,8 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
// JobStatistics.QueryStatistics.StatementType
|
// JobStatistics.QueryStatistics.StatementType
|
||||||
query := bqClient.Query(sql)
|
query := bqClient.Query(sql)
|
||||||
query.Location = bqClient.Location
|
query.Location = bqClient.Location
|
||||||
session, err := t.SessionProvider(ctx)
|
query.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||||
|
session, err := t.SessionProvider(ctx, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -230,10 +230,11 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
query := bqClient.Query(newStatement)
|
query := bqClient.Query(newStatement)
|
||||||
query.Parameters = highLevelParams
|
query.Parameters = highLevelParams
|
||||||
query.Location = bqClient.Location
|
query.Location = bqClient.Location
|
||||||
|
query.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||||
|
|
||||||
connProps := []*bigqueryapi.ConnectionProperty{}
|
connProps := []*bigqueryapi.ConnectionProperty{}
|
||||||
if t.SessionProvider != nil {
|
if t.SessionProvider != nil {
|
||||||
session, err := t.SessionProvider(ctx)
|
session, err := t.SessionProvider(ctx, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||||
}
|
}
|
||||||
@@ -243,7 +244,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
query.ConnectionProperties = connProps
|
query.ConnectionProperties = connProps
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), query.Location, newStatement, lowLevelParams, connProps)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), query.Location, newStatement, lowLevelParams, connProps, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user