mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-23 06:18:02 -05:00
feat(bigquery): make maximum rows returned from queries configurable (#2262)
This change allows the agent developer to control the maximum number of rows returned from tools that run BigQuery SQL queries. With this feature, the agent developer can limit how much output is presented to the LLM in an agentic user journey. ## Description > Should include a concise description of the changes (bug or feature), its > impact, along with a summary of the solution ## PR Checklist > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [x] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [ ] Make sure to open an issue https://github.com/googleapis/genai-toolbox/issues/2261 before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) - [ ] Make sure to add `!` if this involves a breaking change 🛠️ Fixes #2261 2261
This commit is contained in:
@@ -134,6 +134,7 @@ sources:
|
||||
# scopes: # Optional: List of OAuth scopes to request.
|
||||
# - "https://www.googleapis.com/auth/bigquery"
|
||||
# - "https://www.googleapis.com/auth/drive.readonly"
|
||||
# maxQueryResultRows: 50 # Optional: Limits the number of rows returned by queries. Defaults to 50.
|
||||
```
|
||||
|
||||
Initialize a BigQuery source that uses the client's access token:
|
||||
@@ -153,6 +154,7 @@ sources:
|
||||
# scopes: # Optional: List of OAuth scopes to request.
|
||||
# - "https://www.googleapis.com/auth/bigquery"
|
||||
# - "https://www.googleapis.com/auth/drive.readonly"
|
||||
# maxQueryResultRows: 50 # Optional: Limits the number of rows returned by queries. Defaults to 50.
|
||||
```
|
||||
|
||||
## Reference
|
||||
@@ -167,3 +169,4 @@ sources:
|
||||
| useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. **Note:** This cannot be used with `writeMode: protected`. |
|
||||
| scopes | []string | false | A list of OAuth 2.0 scopes to use for the credentials. If not provided, default scopes are used. |
|
||||
| impersonateServiceAccount | string | false | Service account email to impersonate when making BigQuery and Dataplex API calls. The authenticated principal must have the `roles/iam.serviceAccountTokenCreator` role on the target service account. [Learn More](https://cloud.google.com/iam/docs/service-account-impersonation) |
|
||||
| maxQueryResultRows | int | false | The maximum number of rows to return from a query. Defaults to 50. |
|
||||
|
||||
@@ -19,6 +19,7 @@ sources:
|
||||
location: ${BIGQUERY_LOCATION:}
|
||||
useClientOAuth: ${BIGQUERY_USE_CLIENT_OAUTH:false}
|
||||
scopes: ${BIGQUERY_SCOPES:}
|
||||
maxQueryResultRows: ${BIGQUERY_MAX_QUERY_RESULT_ROWS:50}
|
||||
|
||||
tools:
|
||||
analyze_contribution:
|
||||
|
||||
@@ -89,6 +89,7 @@ type Config struct {
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
ImpersonateServiceAccount string `yaml:"impersonateServiceAccount"`
|
||||
Scopes StringOrStringSlice `yaml:"scopes"`
|
||||
MaxQueryResultRows int `yaml:"maxQueryResultRows"`
|
||||
}
|
||||
|
||||
// StringOrStringSlice is a custom type that can unmarshal both a single string
|
||||
@@ -127,6 +128,10 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
r.WriteMode = WriteModeAllowed
|
||||
}
|
||||
|
||||
if r.MaxQueryResultRows == 0 {
|
||||
r.MaxQueryResultRows = 50
|
||||
}
|
||||
|
||||
if r.WriteMode == WriteModeProtected && r.UseClientOAuth {
|
||||
// The protected mode only allows write operations to the session's temporary datasets.
|
||||
// when using client OAuth, a new session is created every
|
||||
@@ -150,7 +155,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
Client: client,
|
||||
RestService: restService,
|
||||
TokenSource: tokenSource,
|
||||
MaxQueryResultRows: 50,
|
||||
MaxQueryResultRows: r.MaxQueryResultRows,
|
||||
ClientCreator: clientCreator,
|
||||
}
|
||||
|
||||
@@ -567,7 +572,7 @@ func (s *Source) RunSQL(ctx context.Context, bqClient *bigqueryapi.Client, state
|
||||
}
|
||||
|
||||
var out []any
|
||||
for {
|
||||
for s.MaxQueryResultRows <= 0 || len(out) < s.MaxQueryResultRows {
|
||||
var val []bigqueryapi.Value
|
||||
err = it.Next(&val)
|
||||
if err == iterator.Done {
|
||||
|
||||
@@ -21,9 +21,12 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
|
||||
func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
@@ -154,6 +157,26 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with max query result rows example",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
maxQueryResultRows: 10
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
MaxQueryResultRows: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
@@ -220,6 +243,59 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitialize_MaxQueryResultRows checks that Config.Initialize produces a
// *bigquery.Source whose MaxQueryResultRows is 50 when the config leaves the
// field unset, and equals the configured value when one is provided.
func TestInitialize_MaxQueryResultRows(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
ctx = util.WithUserAgent(ctx, "test-agent")
|
||||
tracer := noop.NewTracerProvider().Tracer("")
|
||||
|
||||
// Table of cases: each pairs a source config with the row limit expected on
// the initialized source.
tcs := []struct {
|
||||
desc string
|
||||
cfg bigquery.Config
|
||||
want int
|
||||
}{
|
||||
{
|
||||
// MaxQueryResultRows is left at its zero value in this case.
desc: "default value",
|
||||
cfg: bigquery.Config{
|
||||
Name: "test-default",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "test-project",
|
||||
// NOTE(review): UseClientOAuth is set in both cases, presumably so that
// Initialize does not need real credentials in the test — confirm.
UseClientOAuth: true,
|
||||
},
|
||||
want: 50,
|
||||
},
|
||||
{
|
||||
desc: "configured value",
|
||||
cfg: bigquery.Config{
|
||||
Name: "test-configured",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "test-project",
|
||||
UseClientOAuth: true,
|
||||
MaxQueryResultRows: 100,
|
||||
},
|
||||
want: 100,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
src, err := tc.cfg.Initialize(ctx, tracer)
|
||||
if err != nil {
|
||||
t.Fatalf("Initialize failed: %v", err)
|
||||
}
|
||||
// The concrete *bigquery.Source is needed to read MaxQueryResultRows.
bqSrc, ok := src.(*bigquery.Source)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *bigquery.Source, got %T", src)
|
||||
}
|
||||
if bqSrc.MaxQueryResultRows != tc.want {
|
||||
t.Errorf("MaxQueryResultRows = %d, want %d", bqSrc.MaxQueryResultRows, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
Reference in New Issue
Block a user