mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-29 09:18:12 -05:00
Compare commits
1 Commits
processing
...
bigquery-j
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
605b3cdef0 |
@@ -52,7 +52,7 @@ var _ sources.SourceConfig = Config{}
|
|||||||
|
|
||||||
type BigqueryClientCreator func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
type BigqueryClientCreator func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||||
|
|
||||||
type BigQuerySessionProvider func(ctx context.Context) (*Session, error)
|
type BigQuerySessionProvider func(ctx context.Context, toolName string) (*Session, error)
|
||||||
|
|
||||||
type DataplexClientCreator func(tokenString string) (*dataplexapi.CatalogClient, error)
|
type DataplexClientCreator func(tokenString string) (*dataplexapi.CatalogClient, error)
|
||||||
|
|
||||||
@@ -287,7 +287,7 @@ func (s *Source) BigQuerySession() BigQuerySessionProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
||||||
return func(ctx context.Context) (*Session, error) {
|
return func(ctx context.Context, toolName string) (*Session, error) {
|
||||||
if s.WriteMode != WriteModeProtected {
|
if s.WriteMode != WriteModeProtected {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -300,6 +300,8 @@ func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
|||||||
return nil, fmt.Errorf("failed to get logger from context: %w", err)
|
return nil, fmt.Errorf("failed to get logger from context: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
labels := map[string]string{"genai-toolbox-tool": toolName}
|
||||||
|
|
||||||
if s.Session != nil {
|
if s.Session != nil {
|
||||||
// Absolute 7-day lifetime check.
|
// Absolute 7-day lifetime check.
|
||||||
const sessionMaxLifetime = 7 * 24 * time.Hour
|
const sessionMaxLifetime = 7 * 24 * time.Hour
|
||||||
@@ -310,6 +312,7 @@ func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
|||||||
} else {
|
} else {
|
||||||
job := &bigqueryrestapi.Job{
|
job := &bigqueryrestapi.Job{
|
||||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||||
|
Labels: labels,
|
||||||
DryRun: true,
|
DryRun: true,
|
||||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||||
Query: "SELECT 1",
|
Query: "SELECT 1",
|
||||||
@@ -337,6 +340,7 @@ func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
|||||||
Location: s.Location,
|
Location: s.Location,
|
||||||
},
|
},
|
||||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||||
|
Labels: labels,
|
||||||
DryRun: true,
|
DryRun: true,
|
||||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||||
Query: "SELECT 1",
|
Query: "SELECT 1",
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
|
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
|
||||||
if len(t.AllowedDatasets) > 0 {
|
if len(t.AllowedDatasets) > 0 {
|
||||||
var connProps []*bigqueryapi.ConnectionProperty
|
var connProps []*bigqueryapi.ConnectionProperty
|
||||||
session, err := t.SessionProvider(ctx)
|
session, err := t.SessionProvider(ctx, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||||
}
|
}
|
||||||
@@ -240,7 +240,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
{Key: "session_id", Value: session.ID},
|
{Key: "session_id", Value: session.ID},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -289,10 +289,11 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
)
|
)
|
||||||
|
|
||||||
createModelQuery := bqClient.Query(createModelSQL)
|
createModelQuery := bqClient.Query(createModelSQL)
|
||||||
|
createModelQuery.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||||
|
|
||||||
// Get session from provider if in protected mode.
|
// Get session from provider if in protected mode.
|
||||||
// Otherwise, a new session will be created by the first query.
|
// Otherwise, a new session will be created by the first query.
|
||||||
session, err := t.SessionProvider(ctx)
|
session, err := t.SessionProvider(ctx, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||||
}
|
}
|
||||||
@@ -332,6 +333,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)
|
getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)
|
||||||
|
|
||||||
getInsightsQuery := bqClient.Query(getInsightsSQL)
|
getInsightsQuery := bqClient.Query(getInsightsSQL)
|
||||||
|
getInsightsQuery.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||||
getInsightsQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}}
|
getInsightsQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}}
|
||||||
|
|
||||||
job, err := getInsightsQuery.Run(ctx)
|
job, err := getInsightsQuery.Run(ctx)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// DryRunQuery performs a dry run of the SQL query to validate it and get metadata.
|
// DryRunQuery performs a dry run of the SQL query to validate it and get metadata.
|
||||||
func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, projectID string, location string, sql string, params []*bigqueryrestapi.QueryParameter, connProps []*bigqueryapi.ConnectionProperty) (*bigqueryrestapi.Job, error) {
|
func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, projectID string, location string, sql string, params []*bigqueryrestapi.QueryParameter, connProps []*bigqueryapi.ConnectionProperty, toolName string) (*bigqueryrestapi.Job, error) {
|
||||||
useLegacySql := false
|
useLegacySql := false
|
||||||
|
|
||||||
restConnProps := make([]*bigqueryrestapi.ConnectionProperty, len(connProps))
|
restConnProps := make([]*bigqueryrestapi.ConnectionProperty, len(connProps))
|
||||||
@@ -40,6 +40,7 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj
|
|||||||
Location: location,
|
Location: location,
|
||||||
},
|
},
|
||||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||||
|
Labels: getLabels(toolName),
|
||||||
DryRun: true,
|
DryRun: true,
|
||||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||||
Query: sql,
|
Query: sql,
|
||||||
@@ -57,6 +58,10 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj
|
|||||||
return insertResponse, nil
|
return insertResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getLabels(toolName string) map[string]string {
|
||||||
|
return map[string]string{"genai-toolbox-tool": toolName}
|
||||||
|
}
|
||||||
|
|
||||||
// BQTypeStringFromToolType converts a tool parameter type string to a BigQuery standard SQL type string.
|
// BQTypeStringFromToolType converts a tool parameter type string to a BigQuery standard SQL type string.
|
||||||
func BQTypeStringFromToolType(toolType string) (string, error) {
|
func BQTypeStringFromToolType(toolType string) (string, error) {
|
||||||
switch toolType {
|
switch toolType {
|
||||||
|
|||||||
@@ -99,10 +99,11 @@ type InlineContext struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CAPayload struct {
|
type CAPayload struct {
|
||||||
Project string `json:"project"`
|
Project string `json:"project"`
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
InlineContext InlineContext `json:"inlineContext"`
|
InlineContext InlineContext `json:"inlineContext"`
|
||||||
ClientIdEnum string `json:"clientIdEnum"`
|
ClientIdEnum string `json:"clientIdEnum"`
|
||||||
|
JobLabels map[string]string `json:"jobLabels,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate compatible sources are still compatible
|
// validate compatible sources are still compatible
|
||||||
@@ -276,6 +277,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
Options: Options{Chart: ChartOptions{Image: ImageOptions{NoImage: map[string]any{}}}},
|
Options: Options{Chart: ChartOptions{Image: ImageOptions{NoImage: map[string]any{}}}},
|
||||||
},
|
},
|
||||||
ClientIdEnum: "GENAI_TOOLBOX",
|
ClientIdEnum: "GENAI_TOOLBOX",
|
||||||
|
JobLabels: map[string]string{"genai-toolbox-tool": kind},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call the streaming API
|
// Call the streaming API
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
var connProps []*bigqueryapi.ConnectionProperty
|
var connProps []*bigqueryapi.ConnectionProperty
|
||||||
var session *bigqueryds.Session
|
var session *bigqueryds.Session
|
||||||
if t.WriteMode == bigqueryds.WriteModeProtected {
|
if t.WriteMode == bigqueryds.WriteModeProtected {
|
||||||
session, err = t.SessionProvider(ctx)
|
session, err = t.SessionProvider(ctx, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
|
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
|
||||||
}
|
}
|
||||||
@@ -214,7 +214,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -303,6 +303,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
|
|
||||||
query := bqClient.Query(sql)
|
query := bqClient.Query(sql)
|
||||||
query.Location = bqClient.Location
|
query.Location = bqClient.Location
|
||||||
|
query.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||||
|
|
||||||
query.ConnectionProperties = connProps
|
query.ConnectionProperties = connProps
|
||||||
|
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
|
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
|
||||||
if len(t.AllowedDatasets) > 0 {
|
if len(t.AllowedDatasets) > 0 {
|
||||||
var connProps []*bigqueryapi.ConnectionProperty
|
var connProps []*bigqueryapi.ConnectionProperty
|
||||||
session, err := t.SessionProvider(ctx)
|
session, err := t.SessionProvider(ctx, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||||
}
|
}
|
||||||
@@ -218,7 +218,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
{Key: "session_id", Value: session.ID},
|
{Key: "session_id", Value: session.ID},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -279,7 +279,8 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
// JobStatistics.QueryStatistics.StatementType
|
// JobStatistics.QueryStatistics.StatementType
|
||||||
query := bqClient.Query(sql)
|
query := bqClient.Query(sql)
|
||||||
query.Location = bqClient.Location
|
query.Location = bqClient.Location
|
||||||
session, err := t.SessionProvider(ctx)
|
query.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||||
|
session, err := t.SessionProvider(ctx, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -230,10 +230,11 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
query := bqClient.Query(newStatement)
|
query := bqClient.Query(newStatement)
|
||||||
query.Parameters = highLevelParams
|
query.Parameters = highLevelParams
|
||||||
query.Location = bqClient.Location
|
query.Location = bqClient.Location
|
||||||
|
query.Labels = map[string]string{"genai-toolbox-tool": kind}
|
||||||
|
|
||||||
connProps := []*bigqueryapi.ConnectionProperty{}
|
connProps := []*bigqueryapi.ConnectionProperty{}
|
||||||
if t.SessionProvider != nil {
|
if t.SessionProvider != nil {
|
||||||
session, err := t.SessionProvider(ctx)
|
session, err := t.SessionProvider(ctx, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||||
}
|
}
|
||||||
@@ -243,7 +244,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
query.ConnectionProperties = connProps
|
query.ConnectionProperties = connProps
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), query.Location, newStatement, lowLevelParams, connProps)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), query.Location, newStatement, lowLevelParams, connProps, kind)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user