mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-29 09:18:12 -05:00
Compare commits
1 Commits
processing
...
bq-invoke-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9947077408 |
@@ -25,6 +25,7 @@ import (
|
|||||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
@@ -247,6 +248,24 @@ func (s *Source) lazyInitDataplexClient(ctx context.Context, tracer trace.Tracer
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Source) RetrieveBQClient(accessToken tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||||
|
bqClient := s.Client
|
||||||
|
restService := s.RestService
|
||||||
|
|
||||||
|
// Initialize new client if using user OAuth token
|
||||||
|
if s.UseClientOAuth {
|
||||||
|
tokenStr, err := accessToken.ParseBearerToken()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error parsing access token: %w", err)
|
||||||
|
}
|
||||||
|
bqClient, restService, err = s.ClientCreator(tokenStr, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bqClient, restService, nil
|
||||||
|
}
|
||||||
|
|
||||||
func initBigQueryConnection(
|
func initBigQueryConnection(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
tracer trace.Tracer,
|
tracer trace.Tracer,
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ type compatibleSource interface {
|
|||||||
BigQueryRestService() *bigqueryrestapi.Service
|
BigQueryRestService() *bigqueryrestapi.Service
|
||||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||||
UseClientAuthorization() bool
|
UseClientAuthorization() bool
|
||||||
|
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate compatible sources are still compatible
|
// validate compatible sources are still compatible
|
||||||
@@ -122,16 +123,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
|
|
||||||
// finish tool setup
|
// finish tool setup
|
||||||
t := Tool{
|
t := Tool{
|
||||||
Name: cfg.Name,
|
Config: cfg,
|
||||||
Kind: kind,
|
Parameters: parameters,
|
||||||
Parameters: parameters,
|
Source: s,
|
||||||
AuthRequired: cfg.AuthRequired,
|
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||||
UseClientOAuth: s.UseClientAuthorization(),
|
mcpManifest: mcpManifest,
|
||||||
ClientCreator: s.BigQueryClientCreator(),
|
|
||||||
Client: s.BigQueryClient(),
|
|
||||||
RestService: s.BigQueryRestService(),
|
|
||||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
|
||||||
mcpManifest: mcpManifest,
|
|
||||||
}
|
}
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
@@ -140,21 +136,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
var _ tools.Tool = Tool{}
|
var _ tools.Tool = Tool{}
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Name string `yaml:"name"`
|
Config
|
||||||
Kind string `yaml:"kind"`
|
Parameters tools.Parameters
|
||||||
AuthRequired []string `yaml:"authRequired"`
|
Source compatibleSource
|
||||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
manifest tools.Manifest
|
||||||
Parameters tools.Parameters `yaml:"parameters"`
|
mcpManifest tools.McpManifest
|
||||||
|
|
||||||
Client *bigqueryapi.Client
|
|
||||||
RestService *bigqueryrestapi.Service
|
|
||||||
ClientCreator bigqueryds.BigqueryClientCreator
|
|
||||||
manifest tools.Manifest
|
|
||||||
mcpManifest tools.McpManifest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invoke runs the contribution analysis.
|
// Invoke runs the contribution analysis.
|
||||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||||
|
s := t.Source
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
inputData, ok := paramsMap["input_data"].(string)
|
inputData, ok := paramsMap["input_data"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -206,19 +197,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
inputDataSource,
|
inputDataSource,
|
||||||
)
|
)
|
||||||
|
|
||||||
bqClient := t.Client
|
bqClient, _, err := s.RetrieveBQClient(accessToken)
|
||||||
var err error
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
// Initialize new client if using user OAuth token
|
|
||||||
if t.UseClientOAuth {
|
|
||||||
tokenStr, err := accessToken.ParseBearerToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
|
||||||
}
|
|
||||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
createModelQuery := bqClient.Query(createModelSQL)
|
createModelQuery := bqClient.Query(createModelSQL)
|
||||||
@@ -299,5 +280,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) RequiresClientAuthorization() bool {
|
func (t Tool) RequiresClientAuthorization() bool {
|
||||||
return t.UseClientOAuth
|
return t.Source.UseClientAuthorization()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
bigqueryapi "cloud.google.com/go/bigquery"
|
bigqueryapi "cloud.google.com/go/bigquery"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
|
"google.golang.org/api/iterator"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DryRunQuery performs a dry run of the SQL query to validate it and get metadata.
|
// DryRunQuery performs a dry run of the SQL query to validate it and get metadata.
|
||||||
@@ -53,3 +55,46 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj
|
|||||||
}
|
}
|
||||||
return insertResponse, nil
|
return insertResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RunQuery(ctx context.Context, statement string, query *bigqueryapi.Query) (any, error) {
|
||||||
|
// Log the query executed for debugging.
|
||||||
|
logger, err := util.LoggerFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error getting logger: %s", err)
|
||||||
|
}
|
||||||
|
logger.DebugContext(ctx, "executing big query execute sql query: %s", statement)
|
||||||
|
|
||||||
|
// This block handles SELECT statements, which return a row set.
|
||||||
|
// We iterate through the results, convert each row into a map of
|
||||||
|
// column names to values, and return the collection of rows.
|
||||||
|
var out []any
|
||||||
|
it, err := query.Read(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
var row map[string]bigqueryapi.Value
|
||||||
|
err = it.Next(&row)
|
||||||
|
if err == iterator.Done {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
|
||||||
|
}
|
||||||
|
vMap := make(map[string]any)
|
||||||
|
for key, value := range row {
|
||||||
|
vMap[key] = value
|
||||||
|
}
|
||||||
|
out = append(out, vMap)
|
||||||
|
}
|
||||||
|
// If the query returned any rows, return them directly.
|
||||||
|
if len(out) > 0 {
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is the fallback for a successful query that doesn't return content.
|
||||||
|
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
|
||||||
|
// However, it is also possible that this was a query that was expected to return rows
|
||||||
|
// but returned none, a case that we cannot distinguish here.
|
||||||
|
return "Query executed successfully and returned no content.", nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
bigqueryapi "cloud.google.com/go/bigquery"
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||||
@@ -53,7 +52,6 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
}
|
}
|
||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
BigQueryClient() *bigqueryapi.Client
|
|
||||||
BigQueryTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, error)
|
BigQueryTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, error)
|
||||||
BigQueryProject() string
|
BigQueryProject() string
|
||||||
BigQueryLocation() string
|
BigQueryLocation() string
|
||||||
@@ -151,33 +149,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
parameters := tools.Parameters{userQueryParameter, tableRefsParameter}
|
parameters := tools.Parameters{userQueryParameter, tableRefsParameter}
|
||||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)
|
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)
|
||||||
|
|
||||||
// Get cloud-platform token source for Gemini Data Analytics API during initialization
|
|
||||||
var bigQueryTokenSourceWithScope oauth2.TokenSource
|
|
||||||
if !s.UseClientAuthorization() {
|
|
||||||
ctx := context.Background()
|
|
||||||
ts, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
|
|
||||||
}
|
|
||||||
bigQueryTokenSourceWithScope = ts
|
|
||||||
}
|
|
||||||
|
|
||||||
// finish tool setup
|
// finish tool setup
|
||||||
t := Tool{
|
t := Tool{
|
||||||
Name: cfg.Name,
|
Config: cfg,
|
||||||
Kind: kind,
|
Source: s,
|
||||||
Project: s.BigQueryProject(),
|
Parameters: parameters,
|
||||||
Location: s.BigQueryLocation(),
|
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||||
Parameters: parameters,
|
mcpManifest: mcpManifest,
|
||||||
AuthRequired: cfg.AuthRequired,
|
|
||||||
Client: s.BigQueryClient(),
|
|
||||||
UseClientOAuth: s.UseClientAuthorization(),
|
|
||||||
TokenSource: bigQueryTokenSourceWithScope,
|
|
||||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
|
||||||
mcpManifest: mcpManifest,
|
|
||||||
MaxQueryResultRows: s.GetMaxQueryResultRows(),
|
|
||||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
|
||||||
AllowedDatasets: allowedDatasets,
|
|
||||||
}
|
}
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
@@ -186,29 +164,20 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
var _ tools.Tool = Tool{}
|
var _ tools.Tool = Tool{}
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Name string `yaml:"name"`
|
Config
|
||||||
Kind string `yaml:"kind"`
|
Parameters tools.Parameters
|
||||||
AuthRequired []string `yaml:"authRequired"`
|
Source compatibleSource
|
||||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
manifest tools.Manifest
|
||||||
Parameters tools.Parameters `yaml:"parameters"`
|
mcpManifest tools.McpManifest
|
||||||
|
|
||||||
Project string
|
|
||||||
Location string
|
|
||||||
Client *bigqueryapi.Client
|
|
||||||
TokenSource oauth2.TokenSource
|
|
||||||
manifest tools.Manifest
|
|
||||||
mcpManifest tools.McpManifest
|
|
||||||
MaxQueryResultRows int
|
|
||||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
|
||||||
AllowedDatasets []string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||||
|
s := t.Source
|
||||||
var tokenStr string
|
var tokenStr string
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Get credentials for the API call
|
// Get credentials for the API call
|
||||||
if t.UseClientOAuth {
|
if s.UseClientAuthorization() {
|
||||||
// Use client-side access token
|
// Use client-side access token
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", tools.ErrUnauthorized)
|
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", tools.ErrUnauthorized)
|
||||||
@@ -218,11 +187,15 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
tokenSource, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
|
||||||
|
}
|
||||||
// Use cloud-platform token source for Gemini Data Analytics API
|
// Use cloud-platform token source for Gemini Data Analytics API
|
||||||
if t.TokenSource == nil {
|
if tokenSource == nil {
|
||||||
return nil, fmt.Errorf("cloud-platform token source is missing")
|
return nil, fmt.Errorf("cloud-platform token source is missing")
|
||||||
}
|
}
|
||||||
token, err := t.TokenSource.Token()
|
token, err := tokenSource.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err)
|
return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err)
|
||||||
}
|
}
|
||||||
@@ -243,17 +216,18 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(t.AllowedDatasets) > 0 {
|
allowedDataset := s.BigQueryAllowedDatasets()
|
||||||
|
if len(allowedDataset) > 0 {
|
||||||
for _, tableRef := range tableRefs {
|
for _, tableRef := range tableRefs {
|
||||||
if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
|
if !s.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
|
||||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
|
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Construct URL, headers, and payload
|
// Construct URL, headers, and payload
|
||||||
projectID := t.Project
|
projectID := s.BigQueryProject()
|
||||||
location := t.Location
|
location := s.BigQueryLocation()
|
||||||
if location == "" {
|
if location == "" {
|
||||||
location = "us"
|
location = "us"
|
||||||
}
|
}
|
||||||
@@ -277,7 +251,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Call the streaming API
|
// Call the streaming API
|
||||||
response, err := getStream(caURL, payload, headers, t.MaxQueryResultRows)
|
response, err := getStream(caURL, payload, headers, s.GetMaxQueryResultRows())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err)
|
return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err)
|
||||||
}
|
}
|
||||||
@@ -302,7 +276,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) RequiresClientAuthorization() bool {
|
func (t Tool) RequiresClientAuthorization() bool {
|
||||||
return t.UseClientOAuth
|
return t.Source.UseClientAuthorization()
|
||||||
}
|
}
|
||||||
|
|
||||||
// StreamMessage represents a single message object from the streaming API response.
|
// StreamMessage represents a single message object from the streaming API response.
|
||||||
|
|||||||
@@ -26,9 +26,7 @@ import (
|
|||||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
|
||||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
"google.golang.org/api/iterator"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const kind string = "bigquery-execute-sql"
|
const kind string = "bigquery-execute-sql"
|
||||||
@@ -54,6 +52,7 @@ type compatibleSource interface {
|
|||||||
UseClientAuthorization() bool
|
UseClientAuthorization() bool
|
||||||
IsDatasetAllowed(projectID, datasetID string) bool
|
IsDatasetAllowed(projectID, datasetID string) bool
|
||||||
BigQueryAllowedDatasets() []string
|
BigQueryAllowedDatasets() []string
|
||||||
|
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate compatible sources are still compatible
|
// validate compatible sources are still compatible
|
||||||
@@ -122,18 +121,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
|
|
||||||
// finish tool setup
|
// finish tool setup
|
||||||
t := Tool{
|
t := Tool{
|
||||||
Name: cfg.Name,
|
Config: cfg,
|
||||||
Kind: kind,
|
Parameters: parameters,
|
||||||
Parameters: parameters,
|
Source: s,
|
||||||
AuthRequired: cfg.AuthRequired,
|
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||||
UseClientOAuth: s.UseClientAuthorization(),
|
mcpManifest: mcpManifest,
|
||||||
ClientCreator: s.BigQueryClientCreator(),
|
|
||||||
Client: s.BigQueryClient(),
|
|
||||||
RestService: s.BigQueryRestService(),
|
|
||||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
|
||||||
AllowedDatasets: allowedDatasets,
|
|
||||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
|
||||||
mcpManifest: mcpManifest,
|
|
||||||
}
|
}
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
@@ -142,22 +134,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
var _ tools.Tool = Tool{}
|
var _ tools.Tool = Tool{}
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Name string `yaml:"name"`
|
Config
|
||||||
Kind string `yaml:"kind"`
|
Parameters tools.Parameters
|
||||||
AuthRequired []string `yaml:"authRequired"`
|
Source compatibleSource
|
||||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
manifest tools.Manifest
|
||||||
Parameters tools.Parameters `yaml:"parameters"`
|
mcpManifest tools.McpManifest
|
||||||
|
|
||||||
Client *bigqueryapi.Client
|
|
||||||
RestService *bigqueryrestapi.Service
|
|
||||||
ClientCreator bigqueryds.BigqueryClientCreator
|
|
||||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
|
||||||
AllowedDatasets []string
|
|
||||||
manifest tools.Manifest
|
|
||||||
mcpManifest tools.McpManifest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||||
|
s := t.Source
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
sql, ok := paramsMap["sql"].(string)
|
sql, ok := paramsMap["sql"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -168,20 +153,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
|
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
|
||||||
}
|
}
|
||||||
|
|
||||||
bqClient := t.Client
|
bqClient, restService, err := s.RetrieveBQClient(accessToken)
|
||||||
restService := t.RestService
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
var err error
|
|
||||||
// Initialize new client if using user OAuth token
|
|
||||||
if t.UseClientOAuth {
|
|
||||||
tokenStr, err := accessToken.ParseBearerToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
|
||||||
}
|
|
||||||
bqClient, restService, err = t.ClientCreator(tokenStr, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, nil)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, nil)
|
||||||
@@ -189,8 +163,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
|
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
|
||||||
}
|
}
|
||||||
statementType := dryRunJob.Statistics.Query.StatementType
|
statementType := dryRunJob.Statistics.Query.StatementType
|
||||||
|
if len(s.BigQueryAllowedDatasets()) > 0 {
|
||||||
if len(t.AllowedDatasets) > 0 {
|
|
||||||
switch statementType {
|
switch statementType {
|
||||||
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
|
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
|
||||||
return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
|
return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
|
||||||
@@ -225,7 +198,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
} else if statementType != "SELECT" {
|
} else if statementType != "SELECT" {
|
||||||
// If dry run yields no tables, fall back to the parser for non-SELECT statements
|
// If dry run yields no tables, fall back to the parser for non-SELECT statements
|
||||||
// to catch unsafe operations like EXECUTE IMMEDIATE.
|
// to catch unsafe operations like EXECUTE IMMEDIATE.
|
||||||
parsedTables, parseErr := bqutil.TableParser(sql, t.Client.Project())
|
parsedTables, parseErr := bqutil.TableParser(sql, t.Source.BigQueryClient().Project())
|
||||||
if parseErr != nil {
|
if parseErr != nil {
|
||||||
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
|
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
|
||||||
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
|
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
|
||||||
@@ -237,7 +210,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
parts := strings.Split(tableID, ".")
|
parts := strings.Split(tableID, ".")
|
||||||
if len(parts) == 3 {
|
if len(parts) == 3 {
|
||||||
projectID, datasetID := parts[0], parts[1]
|
projectID, datasetID := parts[0], parts[1]
|
||||||
if !t.IsDatasetAllowed(projectID, datasetID) {
|
if !s.IsDatasetAllowed(projectID, datasetID) {
|
||||||
return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
|
return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -259,51 +232,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
query := bqClient.Query(sql)
|
query := bqClient.Query(sql)
|
||||||
query.Location = bqClient.Location
|
query.Location = bqClient.Location
|
||||||
|
|
||||||
// Log the query executed for debugging.
|
return bqutil.RunQuery(ctx, sql, query)
|
||||||
logger, err := util.LoggerFromContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
|
||||||
}
|
|
||||||
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sql)
|
|
||||||
|
|
||||||
// This block handles SELECT statements, which return a row set.
|
|
||||||
// We iterate through the results, convert each row into a map of
|
|
||||||
// column names to values, and return the collection of rows.
|
|
||||||
var out []any
|
|
||||||
it, err := query.Read(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
var row map[string]bigqueryapi.Value
|
|
||||||
err = it.Next(&row)
|
|
||||||
if err == iterator.Done {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
|
|
||||||
}
|
|
||||||
vMap := make(map[string]any)
|
|
||||||
for key, value := range row {
|
|
||||||
vMap[key] = value
|
|
||||||
}
|
|
||||||
out = append(out, vMap)
|
|
||||||
}
|
|
||||||
// If the query returned any rows, return them directly.
|
|
||||||
if len(out) > 0 {
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// This handles the standard case for a SELECT query that successfully
|
|
||||||
// executes but returns zero rows.
|
|
||||||
if statementType == "SELECT" {
|
|
||||||
return "The query returned 0 rows.", nil
|
|
||||||
}
|
|
||||||
// This is the fallback for a successful query that doesn't return content.
|
|
||||||
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
|
|
||||||
// However, it is also possible that this was a query that was expected to return rows
|
|
||||||
// but returned none, a case that we cannot distinguish here.
|
|
||||||
return "Query executed successfully and returned no content.", nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||||
@@ -323,5 +252,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) RequiresClientAuthorization() bool {
|
func (t Tool) RequiresClientAuthorization() bool {
|
||||||
return t.UseClientOAuth
|
return t.Source.UseClientAuthorization()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ type compatibleSource interface {
|
|||||||
UseClientAuthorization() bool
|
UseClientAuthorization() bool
|
||||||
IsDatasetAllowed(projectID, datasetID string) bool
|
IsDatasetAllowed(projectID, datasetID string) bool
|
||||||
BigQueryAllowedDatasets() []string
|
BigQueryAllowedDatasets() []string
|
||||||
|
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate compatible sources are still compatible
|
// validate compatible sources are still compatible
|
||||||
@@ -114,18 +115,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
|
|
||||||
// finish tool setup
|
// finish tool setup
|
||||||
t := Tool{
|
t := Tool{
|
||||||
Name: cfg.Name,
|
Config: cfg,
|
||||||
Kind: kind,
|
Parameters: parameters,
|
||||||
Parameters: parameters,
|
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||||
AuthRequired: cfg.AuthRequired,
|
mcpManifest: mcpManifest,
|
||||||
UseClientOAuth: s.UseClientAuthorization(),
|
|
||||||
ClientCreator: s.BigQueryClientCreator(),
|
|
||||||
Client: s.BigQueryClient(),
|
|
||||||
RestService: s.BigQueryRestService(),
|
|
||||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
|
||||||
AllowedDatasets: allowedDatasets,
|
|
||||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
|
||||||
mcpManifest: mcpManifest,
|
|
||||||
}
|
}
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
@@ -134,22 +127,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
var _ tools.Tool = Tool{}
|
var _ tools.Tool = Tool{}
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Name string `yaml:"name"`
|
Config
|
||||||
Kind string `yaml:"kind"`
|
Parameters tools.Parameters
|
||||||
AuthRequired []string `yaml:"authRequired"`
|
Source compatibleSource
|
||||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
manifest tools.Manifest
|
||||||
Parameters tools.Parameters `yaml:"parameters"`
|
mcpManifest tools.McpManifest
|
||||||
|
|
||||||
Client *bigqueryapi.Client
|
|
||||||
RestService *bigqueryrestapi.Service
|
|
||||||
ClientCreator bigqueryds.BigqueryClientCreator
|
|
||||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
|
||||||
AllowedDatasets []string
|
|
||||||
manifest tools.Manifest
|
|
||||||
mcpManifest tools.McpManifest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||||
|
s := t.Source
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
historyData, ok := paramsMap["history_data"].(string)
|
historyData, ok := paramsMap["history_data"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -187,8 +173,8 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
var historyDataSource string
|
var historyDataSource string
|
||||||
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
|
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
|
||||||
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
|
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
|
||||||
if len(t.AllowedDatasets) > 0 {
|
if len(s.BigQueryAllowedDatasets()) > 0 {
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, t.RestService, t.Client.Project(), t.Client.Location, historyData, nil, nil)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, s.BigQueryRestService(), s.BigQueryClient().Project(), s.BigQueryClient().Location, historyData, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
|
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
|
||||||
}
|
}
|
||||||
@@ -200,7 +186,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
queryStats := dryRunJob.Statistics.Query
|
queryStats := dryRunJob.Statistics.Query
|
||||||
if queryStats != nil {
|
if queryStats != nil {
|
||||||
for _, tableRef := range queryStats.ReferencedTables {
|
for _, tableRef := range queryStats.ReferencedTables {
|
||||||
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
if !s.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||||
return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
|
return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -210,7 +196,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
}
|
}
|
||||||
historyDataSource = fmt.Sprintf("(%s)", historyData)
|
historyDataSource = fmt.Sprintf("(%s)", historyData)
|
||||||
} else {
|
} else {
|
||||||
if len(t.AllowedDatasets) > 0 {
|
if len(s.BigQueryAllowedDatasets()) > 0 {
|
||||||
parts := strings.Split(historyData, ".")
|
parts := strings.Split(historyData, ".")
|
||||||
var projectID, datasetID string
|
var projectID, datasetID string
|
||||||
|
|
||||||
@@ -219,13 +205,13 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
projectID = parts[0]
|
projectID = parts[0]
|
||||||
datasetID = parts[1]
|
datasetID = parts[1]
|
||||||
case 2: // dataset.table
|
case 2: // dataset.table
|
||||||
projectID = t.Client.Project()
|
projectID = s.BigQueryClient().Project()
|
||||||
datasetID = parts[0]
|
datasetID = parts[0]
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData)
|
return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !t.IsDatasetAllowed(projectID, datasetID) {
|
if !s.IsDatasetAllowed(projectID, datasetID) {
|
||||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData)
|
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -246,19 +232,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
horizon => %d%s)`,
|
horizon => %d%s)`,
|
||||||
historyDataSource, dataCol, timestampCol, horizon, idColsArg)
|
historyDataSource, dataCol, timestampCol, horizon, idColsArg)
|
||||||
|
|
||||||
bqClient := t.Client
|
bqClient, _, err := s.RetrieveBQClient(accessToken)
|
||||||
var err error
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
// Initialize new client if using user OAuth token
|
|
||||||
if t.UseClientOAuth {
|
|
||||||
tokenStr, err := accessToken.ParseBearerToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
|
||||||
}
|
|
||||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// JobStatistics.QueryStatistics.StatementType
|
// JobStatistics.QueryStatistics.StatementType
|
||||||
@@ -321,5 +297,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) RequiresClientAuthorization() bool {
|
func (t Tool) RequiresClientAuthorization() bool {
|
||||||
return t.UseClientOAuth
|
return t.Source.UseClientAuthorization()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const kind string = "bigquery-get-dataset-info"
|
const kind string = "bigquery-get-dataset-info"
|
||||||
@@ -48,6 +49,7 @@ type compatibleSource interface {
|
|||||||
BigQueryClient() *bigqueryapi.Client
|
BigQueryClient() *bigqueryapi.Client
|
||||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||||
UseClientAuthorization() bool
|
UseClientAuthorization() bool
|
||||||
|
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate compatible sources are still compatible
|
// validate compatible sources are still compatible
|
||||||
@@ -91,15 +93,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
|
|
||||||
// finish tool setup
|
// finish tool setup
|
||||||
t := Tool{
|
t := Tool{
|
||||||
Name: cfg.Name,
|
Config: cfg,
|
||||||
Kind: kind,
|
Parameters: parameters,
|
||||||
Parameters: parameters,
|
Source: s,
|
||||||
AuthRequired: cfg.AuthRequired,
|
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||||
UseClientOAuth: s.UseClientAuthorization(),
|
mcpManifest: mcpManifest,
|
||||||
ClientCreator: s.BigQueryClientCreator(),
|
|
||||||
Client: s.BigQueryClient(),
|
|
||||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
|
||||||
mcpManifest: mcpManifest,
|
|
||||||
}
|
}
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
@@ -108,20 +106,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
var _ tools.Tool = Tool{}
|
var _ tools.Tool = Tool{}
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Name string `yaml:"name"`
|
Config
|
||||||
Kind string `yaml:"kind"`
|
Parameters tools.Parameters
|
||||||
AuthRequired []string `yaml:"authRequired"`
|
Source compatibleSource
|
||||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
Statement string
|
||||||
Parameters tools.Parameters `yaml:"parameters"`
|
manifest tools.Manifest
|
||||||
|
mcpManifest tools.McpManifest
|
||||||
Client *bigqueryapi.Client
|
|
||||||
ClientCreator bigqueryds.BigqueryClientCreator
|
|
||||||
Statement string
|
|
||||||
manifest tools.Manifest
|
|
||||||
mcpManifest tools.McpManifest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||||
|
s := t.Source
|
||||||
mapParams := params.AsMap()
|
mapParams := params.AsMap()
|
||||||
projectId, ok := mapParams[projectKey].(string)
|
projectId, ok := mapParams[projectKey].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -133,19 +127,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
bqClient := t.Client
|
bqClient, _, err := s.RetrieveBQClient(accessToken)
|
||||||
var err error
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
// Initialize new client if using user OAuth token
|
|
||||||
if t.UseClientOAuth {
|
|
||||||
tokenStr, err := accessToken.ParseBearerToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
|
||||||
}
|
|
||||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
||||||
|
|
||||||
@@ -174,5 +158,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) RequiresClientAuthorization() bool {
|
func (t Tool) RequiresClientAuthorization() bool {
|
||||||
return t.UseClientOAuth
|
return t.Source.UseClientAuthorization()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const kind string = "bigquery-get-table-info"
|
const kind string = "bigquery-get-table-info"
|
||||||
@@ -49,6 +50,7 @@ type compatibleSource interface {
|
|||||||
BigQueryClient() *bigqueryapi.Client
|
BigQueryClient() *bigqueryapi.Client
|
||||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||||
UseClientAuthorization() bool
|
UseClientAuthorization() bool
|
||||||
|
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate compatible sources are still compatible
|
// validate compatible sources are still compatible
|
||||||
@@ -93,15 +95,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
|
|
||||||
// finish tool setup
|
// finish tool setup
|
||||||
t := Tool{
|
t := Tool{
|
||||||
Name: cfg.Name,
|
Config: cfg,
|
||||||
Kind: kind,
|
Parameters: parameters,
|
||||||
Parameters: parameters,
|
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||||
AuthRequired: cfg.AuthRequired,
|
mcpManifest: mcpManifest,
|
||||||
UseClientOAuth: s.UseClientAuthorization(),
|
|
||||||
ClientCreator: s.BigQueryClientCreator(),
|
|
||||||
Client: s.BigQueryClient(),
|
|
||||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
|
||||||
mcpManifest: mcpManifest,
|
|
||||||
}
|
}
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
@@ -110,17 +107,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
var _ tools.Tool = Tool{}
|
var _ tools.Tool = Tool{}
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Name string `yaml:"name"`
|
Config
|
||||||
Kind string `yaml:"kind"`
|
Parameters tools.Parameters
|
||||||
AuthRequired []string `yaml:"authRequired"`
|
|
||||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
|
||||||
Parameters tools.Parameters `yaml:"parameters"`
|
|
||||||
|
|
||||||
Client *bigqueryapi.Client
|
Source compatibleSource
|
||||||
ClientCreator bigqueryds.BigqueryClientCreator
|
Statement string
|
||||||
Statement string
|
manifest tools.Manifest
|
||||||
manifest tools.Manifest
|
mcpManifest tools.McpManifest
|
||||||
mcpManifest tools.McpManifest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||||
@@ -140,19 +133,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
|
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
bqClient := t.Client
|
bqClient, _, err := t.Source.RetrieveBQClient(accessToken)
|
||||||
|
if err != nil {
|
||||||
var err error
|
return nil, err
|
||||||
// Initialize new client if using user OAuth token
|
|
||||||
if t.UseClientOAuth {
|
|
||||||
tokenStr, err := accessToken.ParseBearerToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
|
||||||
}
|
|
||||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
||||||
@@ -183,5 +166,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) RequiresClientAuthorization() bool {
|
func (t Tool) RequiresClientAuthorization() bool {
|
||||||
return t.UseClientOAuth
|
return t.Source.UseClientAuthorization()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
"google.golang.org/api/iterator"
|
"google.golang.org/api/iterator"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -48,6 +49,7 @@ type compatibleSource interface {
|
|||||||
BigQueryClient() *bigqueryapi.Client
|
BigQueryClient() *bigqueryapi.Client
|
||||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||||
UseClientAuthorization() bool
|
UseClientAuthorization() bool
|
||||||
|
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate compatible sources are still compatible
|
// validate compatible sources are still compatible
|
||||||
@@ -91,15 +93,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
|
|
||||||
// finish tool setup
|
// finish tool setup
|
||||||
t := Tool{
|
t := Tool{
|
||||||
Name: cfg.Name,
|
Config: cfg,
|
||||||
Kind: kind,
|
Parameters: parameters,
|
||||||
Parameters: parameters,
|
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||||
AuthRequired: cfg.AuthRequired,
|
mcpManifest: mcpManifest,
|
||||||
UseClientOAuth: s.UseClientAuthorization(),
|
|
||||||
ClientCreator: s.BigQueryClientCreator(),
|
|
||||||
Client: s.BigQueryClient(),
|
|
||||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
|
||||||
mcpManifest: mcpManifest,
|
|
||||||
}
|
}
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
@@ -108,17 +105,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
var _ tools.Tool = Tool{}
|
var _ tools.Tool = Tool{}
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Name string `yaml:"name"`
|
Config
|
||||||
Kind string `yaml:"kind"`
|
Parameters tools.Parameters
|
||||||
AuthRequired []string `yaml:"authRequired"`
|
|
||||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
|
||||||
Parameters tools.Parameters `yaml:"parameters"`
|
|
||||||
|
|
||||||
Client *bigqueryapi.Client
|
Source compatibleSource
|
||||||
ClientCreator bigqueryds.BigqueryClientCreator
|
Statement string
|
||||||
Statement string
|
manifest tools.Manifest
|
||||||
manifest tools.Manifest
|
mcpManifest tools.McpManifest
|
||||||
mcpManifest tools.McpManifest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||||
@@ -128,17 +121,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
bqClient := t.Client
|
bqClient, _, err := t.Source.RetrieveBQClient(accessToken)
|
||||||
// Initialize new client if using user OAuth token
|
if err != nil {
|
||||||
if t.UseClientOAuth {
|
return nil, err
|
||||||
tokenStr, err := accessToken.ParseBearerToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
|
||||||
}
|
|
||||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
datasetIterator := bqClient.Datasets(ctx)
|
datasetIterator := bqClient.Datasets(ctx)
|
||||||
datasetIterator.ProjectID = projectId
|
datasetIterator.ProjectID = projectId
|
||||||
@@ -181,5 +166,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) RequiresClientAuthorization() bool {
|
func (t Tool) RequiresClientAuthorization() bool {
|
||||||
return t.UseClientOAuth
|
return t.Source.UseClientAuthorization()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
"google.golang.org/api/iterator"
|
"google.golang.org/api/iterator"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,6 +54,7 @@ type compatibleSource interface {
|
|||||||
UseClientAuthorization() bool
|
UseClientAuthorization() bool
|
||||||
IsDatasetAllowed(projectID, datasetID string) bool
|
IsDatasetAllowed(projectID, datasetID string) bool
|
||||||
BigQueryAllowedDatasets() []string
|
BigQueryAllowedDatasets() []string
|
||||||
|
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate compatible sources are still compatible
|
// validate compatible sources are still compatible
|
||||||
@@ -132,16 +134,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
|
|
||||||
// finish tool setup
|
// finish tool setup
|
||||||
t := Tool{
|
t := Tool{
|
||||||
Name: cfg.Name,
|
Config: cfg,
|
||||||
Kind: kind,
|
Parameters: parameters,
|
||||||
Parameters: parameters,
|
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||||
AuthRequired: cfg.AuthRequired,
|
mcpManifest: mcpManifest,
|
||||||
UseClientOAuth: s.UseClientAuthorization(),
|
|
||||||
ClientCreator: s.BigQueryClientCreator(),
|
|
||||||
Client: s.BigQueryClient(),
|
|
||||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
|
||||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
|
||||||
mcpManifest: mcpManifest,
|
|
||||||
}
|
}
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
@@ -150,21 +146,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
var _ tools.Tool = Tool{}
|
var _ tools.Tool = Tool{}
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Name string `yaml:"name"`
|
Config
|
||||||
Kind string `yaml:"kind"`
|
Parameters tools.Parameters
|
||||||
AuthRequired []string `yaml:"authRequired"`
|
Source compatibleSource
|
||||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
Statement string
|
||||||
Parameters tools.Parameters `yaml:"parameters"`
|
manifest tools.Manifest
|
||||||
|
mcpManifest tools.McpManifest
|
||||||
Client *bigqueryapi.Client
|
|
||||||
ClientCreator bigqueryds.BigqueryClientCreator
|
|
||||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
|
||||||
Statement string
|
|
||||||
manifest tools.Manifest
|
|
||||||
mcpManifest tools.McpManifest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||||
|
s := t.Source
|
||||||
mapParams := params.AsMap()
|
mapParams := params.AsMap()
|
||||||
projectId, ok := mapParams[projectKey].(string)
|
projectId, ok := mapParams[projectKey].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -176,21 +167,13 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !t.IsDatasetAllowed(projectId, datasetId) {
|
if !s.IsDatasetAllowed(projectId, datasetId) {
|
||||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||||
}
|
}
|
||||||
|
|
||||||
bqClient := t.Client
|
bqClient, _, err := s.RetrieveBQClient(accessToken)
|
||||||
// Initialize new client if using user OAuth token
|
if err != nil {
|
||||||
if t.UseClientOAuth {
|
return nil, err
|
||||||
tokenStr, err := accessToken.ParseBearerToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
|
||||||
}
|
|
||||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
||||||
@@ -234,5 +217,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) RequiresClientAuthorization() bool {
|
func (t Tool) RequiresClientAuthorization() bool {
|
||||||
return t.UseClientOAuth
|
return t.Source.UseClientAuthorization()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -82,9 +82,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the Dataplex client using the method from the source
|
|
||||||
makeCatalogClient := s.MakeDataplexCatalogClient()
|
|
||||||
|
|
||||||
prompt := tools.NewStringParameter("prompt", "Prompt representing search intention. Do not rewrite the prompt.")
|
prompt := tools.NewStringParameter("prompt", "Prompt representing search intention. Do not rewrite the prompt.")
|
||||||
datasetIds := tools.NewArrayParameterWithDefault("datasetIds", []any{}, "Array of dataset IDs.", tools.NewStringParameter("datasetId", "The IDs of the bigquery dataset."))
|
datasetIds := tools.NewArrayParameterWithDefault("datasetIds", []any{}, "Array of dataset IDs.", tools.NewStringParameter("datasetId", "The IDs of the bigquery dataset."))
|
||||||
projectIds := tools.NewArrayParameterWithDefault("projectIds", []any{}, "Array of project IDs.", tools.NewStringParameter("projectId", "The IDs of the bigquery project."))
|
projectIds := tools.NewArrayParameterWithDefault("projectIds", []any{}, "Array of project IDs.", tools.NewStringParameter("projectId", "The IDs of the bigquery project."))
|
||||||
@@ -99,13 +96,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, parameters)
|
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, parameters)
|
||||||
|
|
||||||
t := Tool{
|
t := Tool{
|
||||||
Name: cfg.Name,
|
Config: cfg,
|
||||||
Kind: kind,
|
Parameters: parameters,
|
||||||
Parameters: parameters,
|
Source: s,
|
||||||
AuthRequired: cfg.AuthRequired,
|
|
||||||
UseClientOAuth: s.UseClientAuthorization(),
|
|
||||||
MakeCatalogClient: makeCatalogClient,
|
|
||||||
ProjectID: s.BigQueryProject(),
|
|
||||||
manifest: tools.Manifest{
|
manifest: tools.Manifest{
|
||||||
Description: cfg.Description,
|
Description: cfg.Description,
|
||||||
Parameters: parameters.Manifest(),
|
Parameters: parameters.Manifest(),
|
||||||
@@ -117,15 +110,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Name string
|
Config
|
||||||
Kind string
|
Parameters tools.Parameters
|
||||||
Parameters tools.Parameters
|
Source compatibleSource
|
||||||
AuthRequired []string
|
manifest tools.Manifest
|
||||||
UseClientOAuth bool
|
mcpManifest tools.McpManifest
|
||||||
MakeCatalogClient func() (*dataplexapi.CatalogClient, bigqueryds.DataplexClientCreator, error)
|
|
||||||
ProjectID string
|
|
||||||
manifest tools.Manifest
|
|
||||||
mcpManifest tools.McpManifest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||||
@@ -133,7 +122,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) RequiresClientAuthorization() bool {
|
func (t Tool) RequiresClientAuthorization() bool {
|
||||||
return t.UseClientOAuth
|
return t.Source.UseClientAuthorization()
|
||||||
}
|
}
|
||||||
|
|
||||||
func constructSearchQueryHelper(predicate string, operator string, items []string) string {
|
func constructSearchQueryHelper(predicate string, operator string, items []string) string {
|
||||||
@@ -206,6 +195,7 @@ func ExtractType(resourceString string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||||
|
s := t.Source
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
pageSize := int32(paramsMap["pageSize"].(int))
|
pageSize := int32(paramsMap["pageSize"].(int))
|
||||||
prompt, _ := paramsMap["prompt"].(string)
|
prompt, _ := paramsMap["prompt"].(string)
|
||||||
@@ -227,14 +217,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
|
|
||||||
req := &dataplexpb.SearchEntriesRequest{
|
req := &dataplexpb.SearchEntriesRequest{
|
||||||
Query: fmt.Sprintf("%s %s", prompt, constructSearchQuery(projectIds, datasetIds, types)),
|
Query: fmt.Sprintf("%s %s", prompt, constructSearchQuery(projectIds, datasetIds, types)),
|
||||||
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
|
Name: fmt.Sprintf("projects/%s/locations/global", s.BigQueryProject()),
|
||||||
PageSize: pageSize,
|
PageSize: pageSize,
|
||||||
SemanticSearch: true,
|
SemanticSearch: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
catalogClient, dataplexClientCreator, _ := t.MakeCatalogClient()
|
catalogClient, dataplexClientCreator, _ := s.MakeDataplexCatalogClient()()
|
||||||
|
|
||||||
if t.UseClientOAuth {
|
if s.UseClientAuthorization() {
|
||||||
tokenStr, err := accessToken.ParseBearerToken()
|
tokenStr, err := accessToken.ParseBearerToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||||
@@ -247,7 +237,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
|
|
||||||
it := catalogClient.SearchEntries(ctx, req)
|
it := catalogClient.SearchEntries(ctx, req)
|
||||||
if it == nil {
|
if it == nil {
|
||||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
|
return nil, fmt.Errorf("failed to create search entries iterator for project %q", s.BigQueryProject())
|
||||||
}
|
}
|
||||||
|
|
||||||
var results []Response
|
var results []Response
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
"google.golang.org/api/iterator"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const kind string = "bigquery-sql"
|
const kind string = "bigquery-sql"
|
||||||
@@ -52,6 +51,7 @@ type compatibleSource interface {
|
|||||||
BigQueryRestService() *bigqueryrestapi.Service
|
BigQueryRestService() *bigqueryrestapi.Service
|
||||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||||
UseClientAuthorization() bool
|
UseClientAuthorization() bool
|
||||||
|
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate compatible sources are still compatible
|
// validate compatible sources are still compatible
|
||||||
@@ -99,20 +99,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
|
|
||||||
// finish tool setup
|
// finish tool setup
|
||||||
t := Tool{
|
t := Tool{
|
||||||
Name: cfg.Name,
|
Config: cfg,
|
||||||
Kind: kind,
|
AllParams: allParameters,
|
||||||
AuthRequired: cfg.AuthRequired,
|
Source: s,
|
||||||
Parameters: cfg.Parameters,
|
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||||
TemplateParameters: cfg.TemplateParameters,
|
mcpManifest: mcpManifest,
|
||||||
AllParams: allParameters,
|
|
||||||
|
|
||||||
Statement: cfg.Statement,
|
|
||||||
UseClientOAuth: s.UseClientAuthorization(),
|
|
||||||
Client: s.BigQueryClient(),
|
|
||||||
RestService: s.BigQueryRestService(),
|
|
||||||
ClientCreator: s.BigQueryClientCreator(),
|
|
||||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
|
||||||
mcpManifest: mcpManifest,
|
|
||||||
}
|
}
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
@@ -121,32 +112,24 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
var _ tools.Tool = Tool{}
|
var _ tools.Tool = Tool{}
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Name string `yaml:"name"`
|
Config
|
||||||
Kind string `yaml:"kind"`
|
AllParams tools.Parameters `yaml:"allParams"`
|
||||||
AuthRequired []string `yaml:"authRequired"`
|
|
||||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
|
||||||
Parameters tools.Parameters `yaml:"parameters"`
|
|
||||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
|
||||||
AllParams tools.Parameters `yaml:"allParams"`
|
|
||||||
|
|
||||||
Statement string
|
Source compatibleSource
|
||||||
Client *bigqueryapi.Client
|
Statement string
|
||||||
RestService *bigqueryrestapi.Service
|
manifest tools.Manifest
|
||||||
ClientCreator bigqueryds.BigqueryClientCreator
|
mcpManifest tools.McpManifest
|
||||||
manifest tools.Manifest
|
|
||||||
mcpManifest tools.McpManifest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||||
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
|
|
||||||
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
|
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Config.Statement, paramsMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
return nil, fmt.Errorf("unable to extract template params %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
|
||||||
|
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
|
||||||
for _, p := range t.Parameters {
|
for _, p := range t.Parameters {
|
||||||
name := p.GetName()
|
name := p.GetName()
|
||||||
value := paramsMap[name]
|
value := paramsMap[name]
|
||||||
@@ -214,71 +197,23 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
|||||||
lowLevelParams = append(lowLevelParams, lowLevelParam)
|
lowLevelParams = append(lowLevelParams, lowLevelParam)
|
||||||
}
|
}
|
||||||
|
|
||||||
bqClient := t.Client
|
bqClient, restService, err := t.Source.RetrieveBQClient(accessToken)
|
||||||
restService := t.RestService
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
// Initialize new client if using user OAuth token
|
|
||||||
if t.UseClientOAuth {
|
|
||||||
tokenStr, err := accessToken.ParseBearerToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
|
||||||
}
|
|
||||||
bqClient, restService, err = t.ClientCreator(tokenStr, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
query := bqClient.Query(newStatement)
|
query := bqClient.Query(newStatement)
|
||||||
query.Parameters = highLevelParams
|
query.Parameters = highLevelParams
|
||||||
query.Location = bqClient.Location
|
query.Location = bqClient.Location
|
||||||
|
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, query.ConnectionProperties)
|
_, err = bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, query.ConnectionProperties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// This is a fallback check in case the switch logic was bypassed.
|
// This is a fallback check in case the switch logic was bypassed.
|
||||||
return nil, fmt.Errorf("final query validation failed: %w", err)
|
return nil, fmt.Errorf("final query validation failed: %w", err)
|
||||||
}
|
}
|
||||||
statementType := dryRunJob.Statistics.Query.StatementType
|
|
||||||
|
|
||||||
// This block handles SELECT statements, which return a row set.
|
return bqutil.RunQuery(ctx, newStatement, query)
|
||||||
// We iterate through the results, convert each row into a map of
|
|
||||||
// column names to values, and return the collection of rows.
|
|
||||||
it, err := query.Read(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var out []any
|
|
||||||
for {
|
|
||||||
var row map[string]bigqueryapi.Value
|
|
||||||
err = it.Next(&row)
|
|
||||||
if err == iterator.Done {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
|
|
||||||
}
|
|
||||||
vMap := make(map[string]any)
|
|
||||||
for key, value := range row {
|
|
||||||
vMap[key] = value
|
|
||||||
}
|
|
||||||
out = append(out, vMap)
|
|
||||||
}
|
|
||||||
// If the query returned any rows, return them directly.
|
|
||||||
if len(out) > 0 {
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// This handles the standard case for a SELECT query that successfully
|
|
||||||
// executes but returns zero rows.
|
|
||||||
if statementType == "SELECT" {
|
|
||||||
return "The query returned 0 rows.", nil
|
|
||||||
}
|
|
||||||
// This is the fallback for a successful query that doesn't return content.
|
|
||||||
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
|
|
||||||
// However, it is also possible that this was a query that was expected to return rows
|
|
||||||
// but returned none, a case that we cannot distinguish here.
|
|
||||||
return "Query executed successfully and returned no content.", nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||||
@@ -298,7 +233,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) RequiresClientAuthorization() bool {
|
func (t Tool) RequiresClientAuthorization() bool {
|
||||||
return t.UseClientOAuth
|
return t.Source.UseClientAuthorization()
|
||||||
}
|
}
|
||||||
|
|
||||||
func BQTypeStringFromToolType(toolType string) (string, error) {
|
func BQTypeStringFromToolType(toolType string) (string, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user