mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 07:28:05 -05:00
fix(tool/mongodb-find): fix find tool limit field (#1570)
The projection checking block in `getOptions` exists early previously, resulting in the limit not being set. 🛠️ Fixes https://github.com/googleapis/genai-toolbox/issues/1491
This commit is contained in:
@@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
@@ -54,7 +55,7 @@ type Config struct {
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
Collection string `yaml:"collection" validate:"required"`
|
||||
FilterPayload string `yaml:"filterPayload" validate:"required"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams" validate:"required"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams"`
|
||||
ProjectPayload string `yaml:"projectPayload"`
|
||||
ProjectParams tools.Parameters `yaml:"projectParams"`
|
||||
SortPayload string `yaml:"sortPayload"`
|
||||
@@ -91,6 +92,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Verify 'limit' value
|
||||
if cfg.Limit <= 0 {
|
||||
return nil, fmt.Errorf("limit must be a positive number, but got %d", cfg.Limit)
|
||||
}
|
||||
|
||||
// Create Toolbox manifest
|
||||
paramManifest := allParameters.Manifest()
|
||||
if paramManifest == nil {
|
||||
@@ -147,7 +153,12 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func getOptions(sortParameters tools.Parameters, projectPayload string, limit int64, paramsMap map[string]any) (*options.FindOptions, error) {
|
||||
func getOptions(ctx context.Context, sortParameters tools.Parameters, projectPayload string, limit int64, paramsMap map[string]any) (*options.FindOptions, error) {
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
opts := options.Find()
|
||||
|
||||
sort := bson.M{}
|
||||
@@ -156,28 +167,28 @@ func getOptions(sortParameters tools.Parameters, projectPayload string, limit in
|
||||
}
|
||||
opts = opts.SetSort(sort)
|
||||
|
||||
if len(projectPayload) == 0 {
|
||||
return opts, nil
|
||||
if len(projectPayload) > 0{
|
||||
|
||||
result, err := tools.PopulateTemplateWithJSON("MongoDBFindProjectString", projectPayload, paramsMap)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating project payload: %s", err)
|
||||
}
|
||||
|
||||
var projection any
|
||||
err = bson.UnmarshalExtJSON([]byte(result), false, &projection)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling projection: %s", err)
|
||||
}
|
||||
|
||||
opts = opts.SetProjection(projection)
|
||||
logger.DebugContext(ctx, "Projection is set to %v", projection)
|
||||
}
|
||||
|
||||
result, err := tools.PopulateTemplateWithJSON("MongoDBFindProjectString", projectPayload, paramsMap)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating project payload: %s", err)
|
||||
}
|
||||
|
||||
var projection any
|
||||
err = bson.UnmarshalExtJSON([]byte(result), false, &projection)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling projection: %s", err)
|
||||
}
|
||||
|
||||
opts = opts.SetProjection(projection)
|
||||
|
||||
if limit > 0 {
|
||||
opts = opts.SetLimit(limit)
|
||||
logger.DebugContext(ctx, "Limit is being set to %d", limit)
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
@@ -189,12 +200,12 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating filter: %s", err)
|
||||
}
|
||||
|
||||
opts, err := getOptions(t.SortParams, t.ProjectPayload, t.Limit, paramsMap)
|
||||
|
||||
opts, err := getOptions(ctx, t.SortParams, t.ProjectPayload, t.Limit, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating options: %s", err)
|
||||
}
|
||||
|
||||
|
||||
var filter = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
|
||||
@@ -119,22 +119,22 @@ func TestMongoDBToolEndpoints(t *testing.T) {
|
||||
|
||||
delete1Want := "1"
|
||||
deleteManyWant := "2"
|
||||
RunToolDeleteInvokeTest(t, delete1Want, deleteManyWant)
|
||||
runToolDeleteInvokeTest(t, delete1Want, deleteManyWant)
|
||||
|
||||
insert1Want := `["68666e1035bb36bf1b4d47fb"]`
|
||||
insertManyWant := `["68667a6436ec7d0363668db7","68667a6436ec7d0363668db8","68667a6436ec7d0363668db9"]`
|
||||
RunToolInsertInvokeTest(t, insert1Want, insertManyWant)
|
||||
runToolInsertInvokeTest(t, insert1Want, insertManyWant)
|
||||
|
||||
update1Want := "1"
|
||||
updateManyWant := "[2,0,2]"
|
||||
RunToolUpdateInvokeTest(t, update1Want, updateManyWant)
|
||||
runToolUpdateInvokeTest(t, update1Want, updateManyWant)
|
||||
|
||||
aggregate1Want := `[{"id":2}]`
|
||||
aggregateManyWant := `[{"id":500},{"id":501}]`
|
||||
RunToolAggregateInvokeTest(t, aggregate1Want, aggregateManyWant)
|
||||
runToolAggregateInvokeTest(t, aggregate1Want, aggregateManyWant)
|
||||
}
|
||||
|
||||
func RunToolDeleteInvokeTest(t *testing.T, delete1Want, deleteManyWant string) {
|
||||
func runToolDeleteInvokeTest(t *testing.T, delete1Want, deleteManyWant string) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
@@ -207,7 +207,7 @@ func RunToolDeleteInvokeTest(t *testing.T, delete1Want, deleteManyWant string) {
|
||||
}
|
||||
}
|
||||
|
||||
func RunToolInsertInvokeTest(t *testing.T, insert1Want, insertManyWant string) {
|
||||
func runToolInsertInvokeTest(t *testing.T, insert1Want, insertManyWant string) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
@@ -280,7 +280,7 @@ func RunToolInsertInvokeTest(t *testing.T, insert1Want, insertManyWant string) {
|
||||
}
|
||||
}
|
||||
|
||||
func RunToolUpdateInvokeTest(t *testing.T, update1Want, updateManyWant string) {
|
||||
func runToolUpdateInvokeTest(t *testing.T, update1Want, updateManyWant string) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
@@ -352,7 +352,7 @@ func RunToolUpdateInvokeTest(t *testing.T, update1Want, updateManyWant string) {
|
||||
})
|
||||
}
|
||||
}
|
||||
func RunToolAggregateInvokeTest(t *testing.T, aggregate1Want string, aggregateManyWant string) {
|
||||
func runToolAggregateInvokeTest(t *testing.T, aggregate1Want string, aggregateManyWant string) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
@@ -446,6 +446,7 @@ func setupMongoDB(t *testing.T, ctx context.Context, database *mongo.Database) f
|
||||
|
||||
documents := []map[string]any{
|
||||
{"_id": 1, "id": 1, "name": "Alice", "email": ServiceAccountEmail},
|
||||
{"_id": 1, "id": 2, "name": "FakeAlice", "email": "fakeAlice@gmail.com"},
|
||||
{"_id": 2, "id": 2, "name": "Jane"},
|
||||
{"_id": 3, "id": 3, "name": "Sid"},
|
||||
{"_id": 4, "id": 4, "name": nil},
|
||||
@@ -497,6 +498,8 @@ func getMongoDBToolsConfig(sourceConfig map[string]any, toolKind string) map[str
|
||||
"filterParams": []any{},
|
||||
"projectPayload": `{ "_id": 1, "id": 1, "name" : 1 }`,
|
||||
"database": MongoDbDatabase,
|
||||
"limit": 1,
|
||||
"sort": `{ "id": 1 }`,
|
||||
},
|
||||
"my-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
|
||||
Reference in New Issue
Block a user