fix(tool/mongodb-find): fix find tool limit field (#1570)

The projection checking block in `getOptions` exists early previously,
resulting in the limit not being set.

🛠️ Fixes https://github.com/googleapis/genai-toolbox/issues/1491
This commit is contained in:
Wenxin Du
2025-09-25 18:16:48 -04:00
committed by GitHub
parent 8dfcbfd5b3
commit 4166bf7ab8
2 changed files with 44 additions and 30 deletions

View File

@@ -21,6 +21,7 @@ import (
"github.com/goccy/go-yaml"
mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
"github.com/googleapis/genai-toolbox/internal/util"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
@@ -54,7 +55,7 @@ type Config struct {
Database string `yaml:"database" validate:"required"`
Collection string `yaml:"collection" validate:"required"`
FilterPayload string `yaml:"filterPayload" validate:"required"`
FilterParams tools.Parameters `yaml:"filterParams" validate:"required"`
FilterParams tools.Parameters `yaml:"filterParams"`
ProjectPayload string `yaml:"projectPayload"`
ProjectParams tools.Parameters `yaml:"projectParams"`
SortPayload string `yaml:"sortPayload"`
@@ -91,6 +92,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, err
}
// Verify 'limit' value
if cfg.Limit <= 0 {
return nil, fmt.Errorf("limit must be a positive number, but got %d", cfg.Limit)
}
// Create Toolbox manifest
paramManifest := allParameters.Manifest()
if paramManifest == nil {
@@ -147,7 +153,12 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func getOptions(sortParameters tools.Parameters, projectPayload string, limit int64, paramsMap map[string]any) (*options.FindOptions, error) {
func getOptions(ctx context.Context, sortParameters tools.Parameters, projectPayload string, limit int64, paramsMap map[string]any) (*options.FindOptions, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
panic(err)
}
opts := options.Find()
sort := bson.M{}
@@ -156,28 +167,28 @@ func getOptions(sortParameters tools.Parameters, projectPayload string, limit in
}
opts = opts.SetSort(sort)
if len(projectPayload) == 0 {
return opts, nil
if len(projectPayload) > 0{
result, err := tools.PopulateTemplateWithJSON("MongoDBFindProjectString", projectPayload, paramsMap)
if err != nil {
return nil, fmt.Errorf("error populating project payload: %s", err)
}
var projection any
err = bson.UnmarshalExtJSON([]byte(result), false, &projection)
if err != nil {
return nil, fmt.Errorf("error unmarshalling projection: %s", err)
}
opts = opts.SetProjection(projection)
logger.DebugContext(ctx, "Projection is set to %v", projection)
}
result, err := tools.PopulateTemplateWithJSON("MongoDBFindProjectString", projectPayload, paramsMap)
if err != nil {
return nil, fmt.Errorf("error populating project payload: %s", err)
}
var projection any
err = bson.UnmarshalExtJSON([]byte(result), false, &projection)
if err != nil {
return nil, fmt.Errorf("error unmarshalling projection: %s", err)
}
opts = opts.SetProjection(projection)
if limit > 0 {
opts = opts.SetLimit(limit)
logger.DebugContext(ctx, "Limit is being set to %d", limit)
}
return opts, nil
}
@@ -189,12 +200,12 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
if err != nil {
return nil, fmt.Errorf("error populating filter: %s", err)
}
opts, err := getOptions(t.SortParams, t.ProjectPayload, t.Limit, paramsMap)
opts, err := getOptions(ctx, t.SortParams, t.ProjectPayload, t.Limit, paramsMap)
if err != nil {
return nil, fmt.Errorf("error populating options: %s", err)
}
var filter = bson.D{}
err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {

View File

@@ -119,22 +119,22 @@ func TestMongoDBToolEndpoints(t *testing.T) {
delete1Want := "1"
deleteManyWant := "2"
RunToolDeleteInvokeTest(t, delete1Want, deleteManyWant)
runToolDeleteInvokeTest(t, delete1Want, deleteManyWant)
insert1Want := `["68666e1035bb36bf1b4d47fb"]`
insertManyWant := `["68667a6436ec7d0363668db7","68667a6436ec7d0363668db8","68667a6436ec7d0363668db9"]`
RunToolInsertInvokeTest(t, insert1Want, insertManyWant)
runToolInsertInvokeTest(t, insert1Want, insertManyWant)
update1Want := "1"
updateManyWant := "[2,0,2]"
RunToolUpdateInvokeTest(t, update1Want, updateManyWant)
runToolUpdateInvokeTest(t, update1Want, updateManyWant)
aggregate1Want := `[{"id":2}]`
aggregateManyWant := `[{"id":500},{"id":501}]`
RunToolAggregateInvokeTest(t, aggregate1Want, aggregateManyWant)
runToolAggregateInvokeTest(t, aggregate1Want, aggregateManyWant)
}
func RunToolDeleteInvokeTest(t *testing.T, delete1Want, deleteManyWant string) {
func runToolDeleteInvokeTest(t *testing.T, delete1Want, deleteManyWant string) {
// Test tool invoke endpoint
invokeTcs := []struct {
name string
@@ -207,7 +207,7 @@ func RunToolDeleteInvokeTest(t *testing.T, delete1Want, deleteManyWant string) {
}
}
func RunToolInsertInvokeTest(t *testing.T, insert1Want, insertManyWant string) {
func runToolInsertInvokeTest(t *testing.T, insert1Want, insertManyWant string) {
// Test tool invoke endpoint
invokeTcs := []struct {
name string
@@ -280,7 +280,7 @@ func RunToolInsertInvokeTest(t *testing.T, insert1Want, insertManyWant string) {
}
}
func RunToolUpdateInvokeTest(t *testing.T, update1Want, updateManyWant string) {
func runToolUpdateInvokeTest(t *testing.T, update1Want, updateManyWant string) {
// Test tool invoke endpoint
invokeTcs := []struct {
name string
@@ -352,7 +352,7 @@ func RunToolUpdateInvokeTest(t *testing.T, update1Want, updateManyWant string) {
})
}
}
func RunToolAggregateInvokeTest(t *testing.T, aggregate1Want string, aggregateManyWant string) {
func runToolAggregateInvokeTest(t *testing.T, aggregate1Want string, aggregateManyWant string) {
// Test tool invoke endpoint
invokeTcs := []struct {
name string
@@ -446,6 +446,7 @@ func setupMongoDB(t *testing.T, ctx context.Context, database *mongo.Database) f
documents := []map[string]any{
{"_id": 1, "id": 1, "name": "Alice", "email": ServiceAccountEmail},
{"_id": 1, "id": 2, "name": "FakeAlice", "email": "fakeAlice@gmail.com"},
{"_id": 2, "id": 2, "name": "Jane"},
{"_id": 3, "id": 3, "name": "Sid"},
{"_id": 4, "id": 4, "name": nil},
@@ -497,6 +498,8 @@ func getMongoDBToolsConfig(sourceConfig map[string]any, toolKind string) map[str
"filterParams": []any{},
"projectPayload": `{ "_id": 1, "id": 1, "name" : 1 }`,
"database": MongoDbDatabase,
"limit": 1,
"sort": `{ "id": 1 }`,
},
"my-tool": map[string]any{
"kind": toolKind,