Compare commits

...

1 Commits

Author SHA1 Message Date
Yuan Teoh
9947077408 draft: self-contained tools 2025-10-07 22:14:56 -07:00
12 changed files with 249 additions and 465 deletions

View File

@@ -25,6 +25,7 @@ import (
dataplexapi "cloud.google.com/go/dataplex/apiv1" dataplexapi "cloud.google.com/go/dataplex/apiv1"
"github.com/goccy/go-yaml" "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@@ -247,6 +248,24 @@ func (s *Source) lazyInitDataplexClient(ctx context.Context, tracer trace.Tracer
} }
} }
func (s *Source) RetrieveBQClient(accessToken tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
bqClient := s.Client
restService := s.RestService
// Initialize new client if using user OAuth token
if s.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = s.ClientCreator(tokenStr, true)
if err != nil {
return nil, nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
return bqClient, restService, nil
}
func initBigQueryConnection( func initBigQueryConnection(
ctx context.Context, ctx context.Context,
tracer trace.Tracer, tracer trace.Tracer,

View File

@@ -50,6 +50,7 @@ type compatibleSource interface {
BigQueryRestService() *bigqueryrestapi.Service BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool UseClientAuthorization() bool
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
} }
// validate compatible sources are still compatible // validate compatible sources are still compatible
@@ -122,14 +123,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup // finish tool setup
t := Tool{ t := Tool{
Name: cfg.Name, Config: cfg,
Kind: kind,
Parameters: parameters, Parameters: parameters,
AuthRequired: cfg.AuthRequired, Source: s,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest, mcpManifest: mcpManifest,
} }
@@ -140,21 +136,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{} var _ tools.Tool = Tool{}
type Tool struct { type Tool struct {
Name string `yaml:"name"` Config
Kind string `yaml:"kind"` Parameters tools.Parameters
AuthRequired []string `yaml:"authRequired"` Source compatibleSource
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
manifest tools.Manifest manifest tools.Manifest
mcpManifest tools.McpManifest mcpManifest tools.McpManifest
} }
// Invoke runs the contribution analysis. // Invoke runs the contribution analysis.
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
paramsMap := params.AsMap() paramsMap := params.AsMap()
inputData, ok := paramsMap["input_data"].(string) inputData, ok := paramsMap["input_data"].(string)
if !ok { if !ok {
@@ -206,19 +197,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
inputDataSource, inputDataSource,
) )
bqClient := t.Client bqClient, _, err := s.RetrieveBQClient(accessToken)
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err) return nil, err
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
} }
createModelQuery := bqClient.Query(createModelSQL) createModelQuery := bqClient.Query(createModelSQL)
@@ -299,5 +280,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
} }
func (t Tool) RequiresClientAuthorization() bool { func (t Tool) RequiresClientAuthorization() bool {
return t.UseClientOAuth return t.Source.UseClientAuthorization()
} }

View File

@@ -19,7 +19,9 @@ import (
"fmt" "fmt"
bigqueryapi "cloud.google.com/go/bigquery" bigqueryapi "cloud.google.com/go/bigquery"
"github.com/googleapis/genai-toolbox/internal/util"
bigqueryrestapi "google.golang.org/api/bigquery/v2" bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
) )
// DryRunQuery performs a dry run of the SQL query to validate it and get metadata. // DryRunQuery performs a dry run of the SQL query to validate it and get metadata.
@@ -53,3 +55,46 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj
} }
return insertResponse, nil return insertResponse, nil
} }
func RunQuery(ctx context.Context, statement string, query *bigqueryapi.Query) (any, error) {
// Log the query executed for debugging.
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("error getting logger: %s", err)
}
logger.DebugContext(ctx, "executing big query execute sql query: %s", statement)
// This block handles SELECT statements, which return a row set.
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
var out []any
it, err := query.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
for {
var row map[string]bigqueryapi.Value
err = it.Next(&row)
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
}
vMap := make(map[string]any)
for key, value := range row {
vMap[key] = value
}
out = append(out, vMap)
}
// If the query returned any rows, return them directly.
if len(out) > 0 {
return out, nil
}
// This is the fallback for a successful query that doesn't return content.
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
// However, it is also possible that this was a query that was expected to return rows
// but returned none, a case that we cannot distinguish here.
return "Query executed successfully and returned no content.", nil
}

View File

@@ -23,7 +23,6 @@ import (
"net/http" "net/http"
"strings" "strings"
bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml" yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
@@ -53,7 +52,6 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
} }
type compatibleSource interface { type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, error) BigQueryTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, error)
BigQueryProject() string BigQueryProject() string
BigQueryLocation() string BigQueryLocation() string
@@ -151,33 +149,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
parameters := tools.Parameters{userQueryParameter, tableRefsParameter} parameters := tools.Parameters{userQueryParameter, tableRefsParameter}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)
// Get cloud-platform token source for Gemini Data Analytics API during initialization
var bigQueryTokenSourceWithScope oauth2.TokenSource
if !s.UseClientAuthorization() {
ctx := context.Background()
ts, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
}
bigQueryTokenSourceWithScope = ts
}
// finish tool setup // finish tool setup
t := Tool{ t := Tool{
Name: cfg.Name, Config: cfg,
Kind: kind, Source: s,
Project: s.BigQueryProject(),
Location: s.BigQueryLocation(),
Parameters: parameters, Parameters: parameters,
AuthRequired: cfg.AuthRequired,
Client: s.BigQueryClient(),
UseClientOAuth: s.UseClientAuthorization(),
TokenSource: bigQueryTokenSourceWithScope,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest, mcpManifest: mcpManifest,
MaxQueryResultRows: s.GetMaxQueryResultRows(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
} }
return t, nil return t, nil
} }
@@ -186,29 +164,20 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{} var _ tools.Tool = Tool{}
type Tool struct { type Tool struct {
Name string `yaml:"name"` Config
Kind string `yaml:"kind"` Parameters tools.Parameters
AuthRequired []string `yaml:"authRequired"` Source compatibleSource
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Project string
Location string
Client *bigqueryapi.Client
TokenSource oauth2.TokenSource
manifest tools.Manifest manifest tools.Manifest
mcpManifest tools.McpManifest mcpManifest tools.McpManifest
MaxQueryResultRows int
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
} }
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
var tokenStr string var tokenStr string
var err error var err error
// Get credentials for the API call // Get credentials for the API call
if t.UseClientOAuth { if s.UseClientAuthorization() {
// Use client-side access token // Use client-side access token
if accessToken == "" { if accessToken == "" {
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", tools.ErrUnauthorized) return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", tools.ErrUnauthorized)
@@ -218,11 +187,15 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("error parsing access token: %w", err) return nil, fmt.Errorf("error parsing access token: %w", err)
} }
} else { } else {
tokenSource, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
}
// Use cloud-platform token source for Gemini Data Analytics API // Use cloud-platform token source for Gemini Data Analytics API
if t.TokenSource == nil { if tokenSource == nil {
return nil, fmt.Errorf("cloud-platform token source is missing") return nil, fmt.Errorf("cloud-platform token source is missing")
} }
token, err := t.TokenSource.Token() token, err := tokenSource.Token()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err) return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err)
} }
@@ -243,17 +216,18 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
} }
} }
if len(t.AllowedDatasets) > 0 { allowedDataset := s.BigQueryAllowedDatasets()
if len(allowedDataset) > 0 {
for _, tableRef := range tableRefs { for _, tableRef := range tableRefs {
if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) { if !s.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID) return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
} }
} }
} }
// Construct URL, headers, and payload // Construct URL, headers, and payload
projectID := t.Project projectID := s.BigQueryProject()
location := t.Location location := s.BigQueryLocation()
if location == "" { if location == "" {
location = "us" location = "us"
} }
@@ -277,7 +251,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
} }
// Call the streaming API // Call the streaming API
response, err := getStream(caURL, payload, headers, t.MaxQueryResultRows) response, err := getStream(caURL, payload, headers, s.GetMaxQueryResultRows())
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err) return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err)
} }
@@ -302,7 +276,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
} }
func (t Tool) RequiresClientAuthorization() bool { func (t Tool) RequiresClientAuthorization() bool {
return t.UseClientOAuth return t.Source.UseClientAuthorization()
} }
// StreamMessage represents a single message object from the streaming API response. // StreamMessage represents a single message object from the streaming API response.

View File

@@ -26,9 +26,7 @@ import (
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
"github.com/googleapis/genai-toolbox/internal/util"
bigqueryrestapi "google.golang.org/api/bigquery/v2" bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
) )
const kind string = "bigquery-execute-sql" const kind string = "bigquery-execute-sql"
@@ -54,6 +52,7 @@ type compatibleSource interface {
UseClientAuthorization() bool UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string BigQueryAllowedDatasets() []string
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
} }
// validate compatible sources are still compatible // validate compatible sources are still compatible
@@ -122,16 +121,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup // finish tool setup
t := Tool{ t := Tool{
Name: cfg.Name, Config: cfg,
Kind: kind,
Parameters: parameters, Parameters: parameters,
AuthRequired: cfg.AuthRequired, Source: s,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest, mcpManifest: mcpManifest,
} }
@@ -142,22 +134,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{} var _ tools.Tool = Tool{}
type Tool struct { type Tool struct {
Name string `yaml:"name"` Config
Kind string `yaml:"kind"` Parameters tools.Parameters
AuthRequired []string `yaml:"authRequired"` Source compatibleSource
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
manifest tools.Manifest manifest tools.Manifest
mcpManifest tools.McpManifest mcpManifest tools.McpManifest
} }
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
paramsMap := params.AsMap() paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string) sql, ok := paramsMap["sql"].(string)
if !ok { if !ok {
@@ -168,20 +153,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"]) return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
} }
bqClient := t.Client bqClient, restService, err := s.RetrieveBQClient(accessToken)
restService := t.RestService
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err) return nil, err
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
} }
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, nil) dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, nil)
@@ -189,8 +163,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("query validation failed during dry run: %w", err) return nil, fmt.Errorf("query validation failed during dry run: %w", err)
} }
statementType := dryRunJob.Statistics.Query.StatementType statementType := dryRunJob.Statistics.Query.StatementType
if len(s.BigQueryAllowedDatasets()) > 0 {
if len(t.AllowedDatasets) > 0 {
switch statementType { switch statementType {
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA": case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType) return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
@@ -225,7 +198,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
} else if statementType != "SELECT" { } else if statementType != "SELECT" {
// If dry run yields no tables, fall back to the parser for non-SELECT statements // If dry run yields no tables, fall back to the parser for non-SELECT statements
// to catch unsafe operations like EXECUTE IMMEDIATE. // to catch unsafe operations like EXECUTE IMMEDIATE.
parsedTables, parseErr := bqutil.TableParser(sql, t.Client.Project()) parsedTables, parseErr := bqutil.TableParser(sql, t.Source.BigQueryClient().Project())
if parseErr != nil { if parseErr != nil {
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail. // If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr) return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
@@ -237,7 +210,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
parts := strings.Split(tableID, ".") parts := strings.Split(tableID, ".")
if len(parts) == 3 { if len(parts) == 3 {
projectID, datasetID := parts[0], parts[1] projectID, datasetID := parts[0], parts[1]
if !t.IsDatasetAllowed(projectID, datasetID) { if !s.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID) return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
} }
} }
@@ -259,51 +232,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
query := bqClient.Query(sql) query := bqClient.Query(sql)
query.Location = bqClient.Location query.Location = bqClient.Location
// Log the query executed for debugging. return bqutil.RunQuery(ctx, sql, query)
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("error getting logger: %s", err)
}
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sql)
// This block handles SELECT statements, which return a row set.
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
var out []any
it, err := query.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
for {
var row map[string]bigqueryapi.Value
err = it.Next(&row)
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
}
vMap := make(map[string]any)
for key, value := range row {
vMap[key] = value
}
out = append(out, vMap)
}
// If the query returned any rows, return them directly.
if len(out) > 0 {
return out, nil
}
// This handles the standard case for a SELECT query that successfully
// executes but returns zero rows.
if statementType == "SELECT" {
return "The query returned 0 rows.", nil
}
// This is the fallback for a successful query that doesn't return content.
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
// However, it is also possible that this was a query that was expected to return rows
// but returned none, a case that we cannot distinguish here.
return "Query executed successfully and returned no content.", nil
} }
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
@@ -323,5 +252,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
} }
func (t Tool) RequiresClientAuthorization() bool { func (t Tool) RequiresClientAuthorization() bool {
return t.UseClientOAuth return t.Source.UseClientAuthorization()
} }

View File

@@ -53,6 +53,7 @@ type compatibleSource interface {
UseClientAuthorization() bool UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string BigQueryAllowedDatasets() []string
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
} }
// validate compatible sources are still compatible // validate compatible sources are still compatible
@@ -114,16 +115,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup // finish tool setup
t := Tool{ t := Tool{
Name: cfg.Name, Config: cfg,
Kind: kind,
Parameters: parameters, Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest, mcpManifest: mcpManifest,
} }
@@ -134,22 +127,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{} var _ tools.Tool = Tool{}
type Tool struct { type Tool struct {
Name string `yaml:"name"` Config
Kind string `yaml:"kind"` Parameters tools.Parameters
AuthRequired []string `yaml:"authRequired"` Source compatibleSource
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
manifest tools.Manifest manifest tools.Manifest
mcpManifest tools.McpManifest mcpManifest tools.McpManifest
} }
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
paramsMap := params.AsMap() paramsMap := params.AsMap()
historyData, ok := paramsMap["history_data"].(string) historyData, ok := paramsMap["history_data"].(string)
if !ok { if !ok {
@@ -187,8 +173,8 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
var historyDataSource string var historyDataSource string
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData)) trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") { if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
if len(t.AllowedDatasets) > 0 { if len(s.BigQueryAllowedDatasets()) > 0 {
dryRunJob, err := bqutil.DryRunQuery(ctx, t.RestService, t.Client.Project(), t.Client.Location, historyData, nil, nil) dryRunJob, err := bqutil.DryRunQuery(ctx, s.BigQueryRestService(), s.BigQueryClient().Project(), s.BigQueryClient().Location, historyData, nil, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("query validation failed during dry run: %w", err) return nil, fmt.Errorf("query validation failed during dry run: %w", err)
} }
@@ -200,7 +186,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
queryStats := dryRunJob.Statistics.Query queryStats := dryRunJob.Statistics.Query
if queryStats != nil { if queryStats != nil {
for _, tableRef := range queryStats.ReferencedTables { for _, tableRef := range queryStats.ReferencedTables {
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { if !s.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId) return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
} }
} }
@@ -210,7 +196,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
} }
historyDataSource = fmt.Sprintf("(%s)", historyData) historyDataSource = fmt.Sprintf("(%s)", historyData)
} else { } else {
if len(t.AllowedDatasets) > 0 { if len(s.BigQueryAllowedDatasets()) > 0 {
parts := strings.Split(historyData, ".") parts := strings.Split(historyData, ".")
var projectID, datasetID string var projectID, datasetID string
@@ -219,13 +205,13 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
projectID = parts[0] projectID = parts[0]
datasetID = parts[1] datasetID = parts[1]
case 2: // dataset.table case 2: // dataset.table
projectID = t.Client.Project() projectID = s.BigQueryClient().Project()
datasetID = parts[0] datasetID = parts[0]
default: default:
return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData) return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData)
} }
if !t.IsDatasetAllowed(projectID, datasetID) { if !s.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData) return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData)
} }
} }
@@ -246,19 +232,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
horizon => %d%s)`, horizon => %d%s)`,
historyDataSource, dataCol, timestampCol, horizon, idColsArg) historyDataSource, dataCol, timestampCol, horizon, idColsArg)
bqClient := t.Client bqClient, _, err := s.RetrieveBQClient(accessToken)
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err) return nil, err
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
} }
// JobStatistics.QueryStatistics.StatementType // JobStatistics.QueryStatistics.StatementType
@@ -321,5 +297,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
} }
func (t Tool) RequiresClientAuthorization() bool { func (t Tool) RequiresClientAuthorization() bool {
return t.UseClientOAuth return t.Source.UseClientAuthorization()
} }

View File

@@ -23,6 +23,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
) )
const kind string = "bigquery-get-dataset-info" const kind string = "bigquery-get-dataset-info"
@@ -48,6 +49,7 @@ type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client BigQueryClient() *bigqueryapi.Client
BigQueryClientCreator() bigqueryds.BigqueryClientCreator BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool UseClientAuthorization() bool
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
} }
// validate compatible sources are still compatible // validate compatible sources are still compatible
@@ -91,13 +93,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup // finish tool setup
t := Tool{ t := Tool{
Name: cfg.Name, Config: cfg,
Kind: kind,
Parameters: parameters, Parameters: parameters,
AuthRequired: cfg.AuthRequired, Source: s,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest, mcpManifest: mcpManifest,
} }
@@ -108,20 +106,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{} var _ tools.Tool = Tool{}
type Tool struct { type Tool struct {
Name string `yaml:"name"` Config
Kind string `yaml:"kind"` Parameters tools.Parameters
AuthRequired []string `yaml:"authRequired"` Source compatibleSource
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string Statement string
manifest tools.Manifest manifest tools.Manifest
mcpManifest tools.McpManifest mcpManifest tools.McpManifest
} }
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
mapParams := params.AsMap() mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string) projectId, ok := mapParams[projectKey].(string)
if !ok { if !ok {
@@ -133,19 +127,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey) return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
} }
bqClient := t.Client bqClient, _, err := s.RetrieveBQClient(accessToken)
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err) return nil, err
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
} }
dsHandle := bqClient.DatasetInProject(projectId, datasetId) dsHandle := bqClient.DatasetInProject(projectId, datasetId)
@@ -174,5 +158,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
} }
func (t Tool) RequiresClientAuthorization() bool { func (t Tool) RequiresClientAuthorization() bool {
return t.UseClientOAuth return t.Source.UseClientAuthorization()
} }

View File

@@ -23,6 +23,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
) )
const kind string = "bigquery-get-table-info" const kind string = "bigquery-get-table-info"
@@ -49,6 +50,7 @@ type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client BigQueryClient() *bigqueryapi.Client
BigQueryClientCreator() bigqueryds.BigqueryClientCreator BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool UseClientAuthorization() bool
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
} }
// validate compatible sources are still compatible // validate compatible sources are still compatible
@@ -93,13 +95,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup // finish tool setup
t := Tool{ t := Tool{
Name: cfg.Name, Config: cfg,
Kind: kind,
Parameters: parameters, Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest, mcpManifest: mcpManifest,
} }
@@ -110,14 +107,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{} var _ tools.Tool = Tool{}
type Tool struct { type Tool struct {
Name string `yaml:"name"` Config
Kind string `yaml:"kind"` Parameters tools.Parameters
AuthRequired []string `yaml:"authRequired"`
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client Source compatibleSource
ClientCreator bigqueryds.BigqueryClientCreator
Statement string Statement string
manifest tools.Manifest manifest tools.Manifest
mcpManifest tools.McpManifest mcpManifest tools.McpManifest
@@ -140,19 +133,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey) return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
} }
bqClient := t.Client bqClient, _, err := t.Source.RetrieveBQClient(accessToken)
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err) return nil, err
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
} }
dsHandle := bqClient.DatasetInProject(projectId, datasetId) dsHandle := bqClient.DatasetInProject(projectId, datasetId)
@@ -183,5 +166,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
} }
func (t Tool) RequiresClientAuthorization() bool { func (t Tool) RequiresClientAuthorization() bool {
return t.UseClientOAuth return t.Source.UseClientAuthorization()
} }

View File

@@ -23,6 +23,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator" "google.golang.org/api/iterator"
) )
@@ -48,6 +49,7 @@ type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client BigQueryClient() *bigqueryapi.Client
BigQueryClientCreator() bigqueryds.BigqueryClientCreator BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool UseClientAuthorization() bool
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
} }
// validate compatible sources are still compatible // validate compatible sources are still compatible
@@ -91,13 +93,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup // finish tool setup
t := Tool{ t := Tool{
Name: cfg.Name, Config: cfg,
Kind: kind,
Parameters: parameters, Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest, mcpManifest: mcpManifest,
} }
@@ -108,14 +105,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{} var _ tools.Tool = Tool{}
type Tool struct { type Tool struct {
Name string `yaml:"name"` Config
Kind string `yaml:"kind"` Parameters tools.Parameters
AuthRequired []string `yaml:"authRequired"`
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client Source compatibleSource
ClientCreator bigqueryds.BigqueryClientCreator
Statement string Statement string
manifest tools.Manifest manifest tools.Manifest
mcpManifest tools.McpManifest mcpManifest tools.McpManifest
@@ -128,17 +121,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey) return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
} }
bqClient := t.Client bqClient, _, err := t.Source.RetrieveBQClient(accessToken)
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err) return nil, err
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
} }
datasetIterator := bqClient.Datasets(ctx) datasetIterator := bqClient.Datasets(ctx)
datasetIterator.ProjectID = projectId datasetIterator.ProjectID = projectId
@@ -181,5 +166,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
} }
func (t Tool) RequiresClientAuthorization() bool { func (t Tool) RequiresClientAuthorization() bool {
return t.UseClientOAuth return t.Source.UseClientAuthorization()
} }

View File

@@ -25,6 +25,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator" "google.golang.org/api/iterator"
) )
@@ -53,6 +54,7 @@ type compatibleSource interface {
UseClientAuthorization() bool UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string BigQueryAllowedDatasets() []string
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
} }
// validate compatible sources are still compatible // validate compatible sources are still compatible
@@ -132,14 +134,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup // finish tool setup
t := Tool{ t := Tool{
Name: cfg.Name, Config: cfg,
Kind: kind,
Parameters: parameters, Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest, mcpManifest: mcpManifest,
} }
@@ -150,21 +146,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{} var _ tools.Tool = Tool{}
type Tool struct { type Tool struct {
Name string `yaml:"name"` Config
Kind string `yaml:"kind"` Parameters tools.Parameters
AuthRequired []string `yaml:"authRequired"` Source compatibleSource
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
Statement string Statement string
manifest tools.Manifest manifest tools.Manifest
mcpManifest tools.McpManifest mcpManifest tools.McpManifest
} }
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
mapParams := params.AsMap() mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string) projectId, ok := mapParams[projectKey].(string)
if !ok { if !ok {
@@ -176,21 +167,13 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey) return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
} }
if !t.IsDatasetAllowed(projectId, datasetId) { if !s.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
} }
bqClient := t.Client bqClient, _, err := s.RetrieveBQClient(accessToken)
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err) return nil, err
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
} }
dsHandle := bqClient.DatasetInProject(projectId, datasetId) dsHandle := bqClient.DatasetInProject(projectId, datasetId)
@@ -234,5 +217,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
} }
func (t Tool) RequiresClientAuthorization() bool { func (t Tool) RequiresClientAuthorization() bool {
return t.UseClientOAuth return t.Source.UseClientAuthorization()
} }

View File

@@ -82,9 +82,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
} }
// Get the Dataplex client using the method from the source
makeCatalogClient := s.MakeDataplexCatalogClient()
prompt := tools.NewStringParameter("prompt", "Prompt representing search intention. Do not rewrite the prompt.") prompt := tools.NewStringParameter("prompt", "Prompt representing search intention. Do not rewrite the prompt.")
datasetIds := tools.NewArrayParameterWithDefault("datasetIds", []any{}, "Array of dataset IDs.", tools.NewStringParameter("datasetId", "The IDs of the bigquery dataset.")) datasetIds := tools.NewArrayParameterWithDefault("datasetIds", []any{}, "Array of dataset IDs.", tools.NewStringParameter("datasetId", "The IDs of the bigquery dataset."))
projectIds := tools.NewArrayParameterWithDefault("projectIds", []any{}, "Array of project IDs.", tools.NewStringParameter("projectId", "The IDs of the bigquery project.")) projectIds := tools.NewArrayParameterWithDefault("projectIds", []any{}, "Array of project IDs.", tools.NewStringParameter("projectId", "The IDs of the bigquery project."))
@@ -99,13 +96,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, parameters) mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, parameters)
t := Tool{ t := Tool{
Name: cfg.Name, Config: cfg,
Kind: kind,
Parameters: parameters, Parameters: parameters,
AuthRequired: cfg.AuthRequired, Source: s,
UseClientOAuth: s.UseClientAuthorization(),
MakeCatalogClient: makeCatalogClient,
ProjectID: s.BigQueryProject(),
manifest: tools.Manifest{ manifest: tools.Manifest{
Description: cfg.Description, Description: cfg.Description,
Parameters: parameters.Manifest(), Parameters: parameters.Manifest(),
@@ -117,13 +110,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
} }
type Tool struct { type Tool struct {
Name string Config
Kind string
Parameters tools.Parameters Parameters tools.Parameters
AuthRequired []string Source compatibleSource
UseClientOAuth bool
MakeCatalogClient func() (*dataplexapi.CatalogClient, bigqueryds.DataplexClientCreator, error)
ProjectID string
manifest tools.Manifest manifest tools.Manifest
mcpManifest tools.McpManifest mcpManifest tools.McpManifest
} }
@@ -133,7 +122,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
} }
func (t Tool) RequiresClientAuthorization() bool { func (t Tool) RequiresClientAuthorization() bool {
return t.UseClientOAuth return t.Source.UseClientAuthorization()
} }
func constructSearchQueryHelper(predicate string, operator string, items []string) string { func constructSearchQueryHelper(predicate string, operator string, items []string) string {
@@ -206,6 +195,7 @@ func ExtractType(resourceString string) string {
} }
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
paramsMap := params.AsMap() paramsMap := params.AsMap()
pageSize := int32(paramsMap["pageSize"].(int)) pageSize := int32(paramsMap["pageSize"].(int))
prompt, _ := paramsMap["prompt"].(string) prompt, _ := paramsMap["prompt"].(string)
@@ -227,14 +217,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
req := &dataplexpb.SearchEntriesRequest{ req := &dataplexpb.SearchEntriesRequest{
Query: fmt.Sprintf("%s %s", prompt, constructSearchQuery(projectIds, datasetIds, types)), Query: fmt.Sprintf("%s %s", prompt, constructSearchQuery(projectIds, datasetIds, types)),
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID), Name: fmt.Sprintf("projects/%s/locations/global", s.BigQueryProject()),
PageSize: pageSize, PageSize: pageSize,
SemanticSearch: true, SemanticSearch: true,
} }
catalogClient, dataplexClientCreator, _ := t.MakeCatalogClient() catalogClient, dataplexClientCreator, _ := s.MakeDataplexCatalogClient()()
if t.UseClientOAuth { if s.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken() tokenStr, err := accessToken.ParseBearerToken()
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err) return nil, fmt.Errorf("error parsing access token: %w", err)
@@ -247,7 +237,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
it := catalogClient.SearchEntries(ctx, req) it := catalogClient.SearchEntries(ctx, req)
if it == nil { if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID) return nil, fmt.Errorf("failed to create search entries iterator for project %q", s.BigQueryProject())
} }
var results []Response var results []Response

View File

@@ -28,7 +28,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
bigqueryrestapi "google.golang.org/api/bigquery/v2" bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
) )
const kind string = "bigquery-sql" const kind string = "bigquery-sql"
@@ -52,6 +51,7 @@ type compatibleSource interface {
BigQueryRestService() *bigqueryrestapi.Service BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool UseClientAuthorization() bool
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
} }
// validate compatible sources are still compatible // validate compatible sources are still compatible
@@ -99,18 +99,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup // finish tool setup
t := Tool{ t := Tool{
Name: cfg.Name, Config: cfg,
Kind: kind,
AuthRequired: cfg.AuthRequired,
Parameters: cfg.Parameters,
TemplateParameters: cfg.TemplateParameters,
AllParams: allParameters, AllParams: allParameters,
Source: s,
Statement: cfg.Statement,
UseClientOAuth: s.UseClientAuthorization(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
ClientCreator: s.BigQueryClientCreator(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest, mcpManifest: mcpManifest,
} }
@@ -121,32 +112,24 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{} var _ tools.Tool = Tool{}
type Tool struct { type Tool struct {
Name string `yaml:"name"` Config
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
TemplateParameters tools.Parameters `yaml:"templateParameters"`
AllParams tools.Parameters `yaml:"allParams"` AllParams tools.Parameters `yaml:"allParams"`
Source compatibleSource
Statement string Statement string
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
manifest tools.Manifest manifest tools.Manifest
mcpManifest tools.McpManifest mcpManifest tools.McpManifest
} }
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
paramsMap := params.AsMap() paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Config.Statement, paramsMap)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to extract template params %w", err) return nil, fmt.Errorf("unable to extract template params %w", err)
} }
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
for _, p := range t.Parameters { for _, p := range t.Parameters {
name := p.GetName() name := p.GetName()
value := paramsMap[name] value := paramsMap[name]
@@ -214,71 +197,23 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
lowLevelParams = append(lowLevelParams, lowLevelParam) lowLevelParams = append(lowLevelParams, lowLevelParam)
} }
bqClient := t.Client bqClient, restService, err := t.Source.RetrieveBQClient(accessToken)
restService := t.RestService
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil { if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err) return nil, err
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
} }
query := bqClient.Query(newStatement) query := bqClient.Query(newStatement)
query.Parameters = highLevelParams query.Parameters = highLevelParams
query.Location = bqClient.Location query.Location = bqClient.Location
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, query.ConnectionProperties) _, err = bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, query.ConnectionProperties)
if err != nil { if err != nil {
// This is a fallback check in case the switch logic was bypassed. // This is a fallback check in case the switch logic was bypassed.
return nil, fmt.Errorf("final query validation failed: %w", err) return nil, fmt.Errorf("final query validation failed: %w", err)
} }
statementType := dryRunJob.Statistics.Query.StatementType
// This block handles SELECT statements, which return a row set. return bqutil.RunQuery(ctx, newStatement, query)
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
it, err := query.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
var out []any
for {
var row map[string]bigqueryapi.Value
err = it.Next(&row)
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
}
vMap := make(map[string]any)
for key, value := range row {
vMap[key] = value
}
out = append(out, vMap)
}
// If the query returned any rows, return them directly.
if len(out) > 0 {
return out, nil
}
// This handles the standard case for a SELECT query that successfully
// executes but returns zero rows.
if statementType == "SELECT" {
return "The query returned 0 rows.", nil
}
// This is the fallback for a successful query that doesn't return content.
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
// However, it is also possible that this was a query that was expected to return rows
// but returned none, a case that we cannot distinguish here.
return "Query executed successfully and returned no content.", nil
} }
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
@@ -298,7 +233,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
} }
func (t Tool) RequiresClientAuthorization() bool { func (t Tool) RequiresClientAuthorization() bool {
return t.UseClientOAuth return t.Source.UseClientAuthorization()
} }
func BQTypeStringFromToolType(toolType string) (string, error) { func BQTypeStringFromToolType(toolType string) (string, error) {