mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 16:08:16 -05:00
feat: Add MongoDB find Tools (#970)
Add MongoDB Tools: - mongodb-find - mongodb-find-one --------- Co-authored-by: Author: Dennis Geurts <dennisg@dennisg.nl> Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
This commit is contained in:
@@ -69,6 +69,8 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquery"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquerysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrunlook"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfind"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfindone"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql"
|
||||
|
||||
7
docs/en/resources/tools/mongodb/_index.md
Normal file
7
docs/en/resources/tools/mongodb/_index.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "MongoDB"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Tools that work with the MongoDB Source.
|
||||
---
|
||||
62
docs/en/resources/tools/mongodb/mongodb-find-one.md
Normal file
62
docs/en/resources/tools/mongodb/mongodb-find-one.md
Normal file
@@ -0,0 +1,62 @@
|
||||
---
|
||||
title: "mongodb-find-one"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "mongodb-find-one" tool finds and retrieves a single document from a MongoDB collection.
|
||||
aliases:
|
||||
- /resources/tools/mongodb-find-one
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `mongodb-find-one` tool is used to retrieve the **first single document** that matches a specified filter from a MongoDB collection. If multiple documents match the filter, you can use `sort` options to control which document is returned. Otherwise, the selection is not guaranteed.
|
||||
|
||||
The tool returns a single JSON object representing the document, wrapped in a JSON array.
|
||||
|
||||
This tool is compatible with the following source kind:
|
||||
|
||||
* [`mongodb`](../../sources/mongodb.md)
|
||||
|
||||
---
|
||||
|
||||
## Example
|
||||
|
||||
Here's a common use case: finding a specific user by their unique email address and returning their profile information, while excluding sensitive fields like the password hash.
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
get_user_profile:
|
||||
kind: mongodb-find-one
|
||||
source: my-mongo-source
|
||||
description: Retrieves a user's profile by their email address.
|
||||
database: user_data
|
||||
collection: profiles
|
||||
filterPayload: |
|
||||
{ "email": {{json .email}} }
|
||||
filterParams:
|
||||
- name: email
|
||||
type: string
|
||||
description: The email address of the user to find.
|
||||
projectPayload: |
|
||||
{
|
||||
"password_hash": 0,
|
||||
"login_history": 0
|
||||
}
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|:---------------|:---------|:-------------|:---------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be `mongodb-find-one`. |
|
||||
| source | string | true | The name of the `mongodb` source to use. |
|
||||
| description | string | true | A description of the tool that is passed to the LLM. |
|
||||
| database | string | true | The name of the MongoDB database to query. |
|
||||
| collection | string | true | The name of the MongoDB collection to query. |
|
||||
| filterPayload | string | true | The MongoDB query filter document to select the document. Uses `{{json .param_name}}` for templating. |
|
||||
| filterParams | list | true | A list of parameter objects that define the variables used in the `filterPayload`. |
|
||||
| projectPayload | string | false | An optional MongoDB projection document to specify which fields to include (1) or exclude (0) in the result. |
|
||||
| projectParams | list | false | A list of parameter objects for the `projectPayload`. |
|
||||
| sortPayload | string | false | An optional MongoDB sort document. Useful for selecting which document to return if the filter matches multiple (e.g., get the most recent). |
|
||||
| sortParams | list | false | A list of parameter objects for the `sortPayload`. |
|
||||
70
docs/en/resources/tools/mongodb/mongodb-find.md
Normal file
70
docs/en/resources/tools/mongodb/mongodb-find.md
Normal file
@@ -0,0 +1,70 @@
|
||||
---
|
||||
title: "mongodb-find"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "mongodb-find" tool finds and retrieves documents from a MongoDB collection.
|
||||
aliases:
|
||||
- /resources/tools/mongodb-find
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `mongodb-find` tool is used to query a MongoDB collection and retrieve documents that match a specified filter. It's a flexible tool that allows you to shape the output by selecting specific fields (**projection**), ordering the results (**sorting**), and restricting the number of documents returned (**limiting**).
|
||||
|
||||
The tool returns a JSON array of the documents found.
|
||||
|
||||
This tool is compatible with the following source kind:
|
||||
|
||||
* [`mongodb`](../../sources/mongodb.md)
|
||||
|
||||
## Example
|
||||
|
||||
Here's an example that finds up to 10 users from the `customers` collection who live in a specific city. The results are sorted by their last name, and only their first name, last name, and email are returned.
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
find_local_customers:
|
||||
kind: mongodb-find
|
||||
source: my-mongo-source
|
||||
description: Finds customers by city, sorted by last name.
|
||||
database: crm
|
||||
collection: customers
|
||||
limit: 10
|
||||
filterPayload: |
|
||||
{ "address.city": {{json .city}} }
|
||||
filterParams:
|
||||
- name: city
|
||||
type: string
|
||||
description: The city to search for customers in.
|
||||
projectPayload: |
|
||||
{
|
||||
"first_name": 1,
|
||||
"last_name": 1,
|
||||
"email": 1,
|
||||
"_id": 0
|
||||
}
|
||||
sortPayload: |
|
||||
{ "last_name": {{json .sort_order}} }
|
||||
sortParams:
|
||||
- name: sort_order
|
||||
type: integer
|
||||
description: The sort order (1 for ascending, -1 for descending).
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|:---------------|:---------|:-------------|:----------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be `mongodb-find`. |
|
||||
| source | string | true | The name of the `mongodb` source to use. |
|
||||
| description | string | true | A description of the tool that is passed to the LLM. |
|
||||
| database | string | true | The name of the MongoDB database to query. |
|
||||
| collection | string | true | The name of the MongoDB collection to query. |
|
||||
| filterPayload | string | true | The MongoDB query filter document to select which documents to return. Uses `{{json .param_name}}` for templating. |
|
||||
| filterParams | list | true | A list of parameter objects that define the variables used in the `filterPayload`. |
|
||||
| projectPayload | string | false | An optional MongoDB projection document to specify which fields to include (1) or exclude (0) in the results. |
|
||||
| projectParams | list | false | A list of parameter objects for the `projectPayload`. |
|
||||
| sortPayload | string | false | An optional MongoDB sort document to define the order of the returned documents. Use 1 for ascending and -1 for descending. |
|
||||
| sortParams | list | false | A list of parameter objects for the `sortPayload`. |
|
||||
| limit | integer | false | An optional integer specifying the maximum number of documents to return. |
|
||||
@@ -15,8 +15,11 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
var validName = regexp.MustCompile(`^[a-zA-Z0-9_-]*$`)
|
||||
@@ -25,6 +28,7 @@ func IsValidName(s string) bool {
|
||||
return validName.MatchString(s)
|
||||
}
|
||||
|
||||
// ConvertAnySliceToTyped a []any to typed slice ([]string, []int, []float etc.)
|
||||
func ConvertAnySliceToTyped(s []any, itemType string) (any, error) {
|
||||
var typedSlice any
|
||||
switch itemType {
|
||||
@@ -71,3 +75,44 @@ func ConvertAnySliceToTyped(s []any, itemType string) (any, error) {
|
||||
}
|
||||
return typedSlice, nil
|
||||
}
|
||||
|
||||
// convertParamToJSON is a Go template helper function to convert a parameter to JSON formatted string.
|
||||
func convertParamToJSON(param any) (string, error) {
|
||||
jsonData, err := json.Marshal(param)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal param to JSON: %w", err)
|
||||
}
|
||||
return string(jsonData), nil
|
||||
}
|
||||
|
||||
// PopulateTemplateWithJSON populate a Go template with a custom `json` array formatter
|
||||
func PopulateTemplateWithJSON(templateName, templateString string, data map[string]any) (string, error) {
|
||||
funcMap := template.FuncMap{
|
||||
"json": convertParamToJSON,
|
||||
}
|
||||
|
||||
tmpl, err := template.New(templateName).Funcs(funcMap).Parse(templateString)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing template '%s': %w", templateName, err)
|
||||
}
|
||||
|
||||
var result bytes.Buffer
|
||||
err = tmpl.Execute(&result, data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error executing template '%s': %w", templateName, err)
|
||||
}
|
||||
return result.String(), nil
|
||||
}
|
||||
|
||||
// Verify there are no duplicate parameter names
|
||||
func CheckDuplicateParameters(ps Parameters) error {
|
||||
seenNames := make(map[string]bool)
|
||||
for _, p := range ps {
|
||||
pName := p.GetName()
|
||||
if _, exists := seenNames[pName]; exists {
|
||||
return fmt.Errorf("parameter name must be unique across all parameter fields. Duplicate parameter: %s", pName)
|
||||
}
|
||||
seenNames[pName] = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -105,6 +105,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
// Verify there are no duplicate parameter names
|
||||
seenNames := make(map[string]bool)
|
||||
for _, param := range paramManifest {
|
||||
if _, exists := seenNames[param.Name]; exists {
|
||||
return nil, fmt.Errorf("parameter name must be unique across queryParams, bodyParams, and headerParams. Duplicate parameter: %s", param.Name)
|
||||
}
|
||||
seenNames[param.Name] = true
|
||||
}
|
||||
|
||||
pathMcpManifest := cfg.PathParams.McpManifest()
|
||||
queryMcpManifest := cfg.QueryParams.McpManifest()
|
||||
bodyMcpManifest := cfg.BodyParams.McpManifest()
|
||||
@@ -143,15 +153,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Required: concatRequiredManifest,
|
||||
}
|
||||
|
||||
// Verify there are no duplicate parameter names
|
||||
seenNames := make(map[string]bool)
|
||||
for _, param := range paramManifest {
|
||||
if _, exists := seenNames[param.Name]; exists {
|
||||
return nil, fmt.Errorf("parameter name must be unique across queryParams, bodyParams, and headerParams. Duplicate parameter: %s", param.Name)
|
||||
}
|
||||
seenNames[param.Name] = true
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
@@ -207,15 +208,6 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// helper function to convert a parameter to JSON formatted string.
|
||||
func convertParamToJSON(param any) (string, error) {
|
||||
jsonData, err := json.Marshal(param)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal param to JSON: %w", err)
|
||||
}
|
||||
return string(jsonData), nil
|
||||
}
|
||||
|
||||
// Helper function to generate the HTTP request body upon Tool invocation.
|
||||
func getRequestBody(bodyParams tools.Parameters, requestBodyPayload string, paramsMap map[string]any) (string, error) {
|
||||
bodyParamValues, err := tools.GetParams(bodyParams, paramsMap)
|
||||
@@ -224,20 +216,11 @@ func getRequestBody(bodyParams tools.Parameters, requestBodyPayload string, para
|
||||
}
|
||||
bodyParamsMap := bodyParamValues.AsMap()
|
||||
|
||||
// Create a FuncMap to format array parameters
|
||||
funcMap := template.FuncMap{
|
||||
"json": convertParamToJSON,
|
||||
}
|
||||
templ, err := template.New("body").Funcs(funcMap).Parse(requestBodyPayload)
|
||||
requestBodyStr, err := tools.PopulateTemplateWithJSON("HTTPToolRequestBody", requestBodyPayload, bodyParamsMap)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing request body: %s", err)
|
||||
return "", err
|
||||
}
|
||||
var result bytes.Buffer
|
||||
err = templ.Execute(&result, bodyParamsMap)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error replacing body payload: %s", err)
|
||||
}
|
||||
return result.String(), nil
|
||||
return requestBodyStr, nil
|
||||
}
|
||||
|
||||
// Helper function to generate the HTTP request URL upon Tool invocation.
|
||||
|
||||
244
internal/tools/mongodb/mongodbfind/mongodbfind.go
Normal file
244
internal/tools/mongodb/mongodbfind/mongodbfind.go
Normal file
@@ -0,0 +1,244 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package mongodbfind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "mongodb-find"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
Collection string `yaml:"collection" validate:"required"`
|
||||
FilterPayload string `yaml:"filterPayload" validate:"required"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams" validate:"required"`
|
||||
ProjectPayload string `yaml:"projectPayload"`
|
||||
ProjectParams tools.Parameters `yaml:"projectParams"`
|
||||
SortPayload string `yaml:"sortPayload"`
|
||||
SortParams tools.Parameters `yaml:"sortParams"`
|
||||
Limit int64 `yaml:"limit"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*mongosrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind)
|
||||
}
|
||||
|
||||
// Create a slice for all parameters
|
||||
allParameters := slices.Concat(cfg.FilterParams, cfg.ProjectParams, cfg.SortParams)
|
||||
|
||||
// Verify no duplicate parameter names
|
||||
err := tools.CheckDuplicateParameters(allParameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create Toolbox manifest
|
||||
paramManifest := allParameters.Manifest()
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
// Create MCP manifest
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: allParameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Collection: cfg.Collection,
|
||||
FilterPayload: cfg.FilterPayload,
|
||||
FilterParams: cfg.FilterParams,
|
||||
ProjectPayload: cfg.ProjectPayload,
|
||||
ProjectParams: cfg.ProjectParams,
|
||||
SortPayload: cfg.SortPayload,
|
||||
SortParams: cfg.SortParams,
|
||||
Limit: cfg.Limit,
|
||||
AllParams: allParameters,
|
||||
database: s.Client.Database(cfg.Database),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Collection string `yaml:"collection"`
|
||||
FilterPayload string `yaml:"filterPayload"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams"`
|
||||
ProjectPayload string `yaml:"projectPayload"`
|
||||
ProjectParams tools.Parameters `yaml:"projectParams"`
|
||||
SortPayload string `yaml:"sortPayload"`
|
||||
SortParams tools.Parameters `yaml:"sortParams"`
|
||||
Limit int64 `yaml:"limit"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
database *mongo.Database
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func getOptions(sortParameters tools.Parameters, projectPayload string, limit int64, paramsMap map[string]any) (*options.FindOptions, error) {
|
||||
opts := options.Find()
|
||||
|
||||
sort := bson.M{}
|
||||
for _, p := range sortParameters {
|
||||
sort[p.GetName()] = paramsMap[p.GetName()]
|
||||
}
|
||||
opts = opts.SetSort(sort)
|
||||
|
||||
if len(projectPayload) == 0 {
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
result, err := tools.PopulateTemplateWithJSON("MongoDBFindProjectString", projectPayload, paramsMap)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating project payload: %s", err)
|
||||
}
|
||||
|
||||
var projection any
|
||||
err = bson.UnmarshalExtJSON([]byte(result), false, &projection)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling projection: %s", err)
|
||||
}
|
||||
|
||||
opts = opts.SetProjection(projection)
|
||||
|
||||
if limit > 0 {
|
||||
opts = opts.SetLimit(limit)
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
filterString, err := tools.PopulateTemplateWithJSON("MongoDBFindFilterString", t.FilterPayload, paramsMap)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating filter: %s", err)
|
||||
}
|
||||
|
||||
opts, err := getOptions(t.SortParams, t.ProjectPayload, t.Limit, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating options: %s", err)
|
||||
}
|
||||
|
||||
var filter = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cur, err := t.database.Collection(t.Collection).Find(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cur.Close(ctx)
|
||||
|
||||
var data = []any{}
|
||||
err = cur.All(context.TODO(), &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var final []any
|
||||
for _, item := range data {
|
||||
tmp, _ := bson.MarshalExtJSON(item, false, false)
|
||||
var tmp2 any
|
||||
err = json.Unmarshal(tmp, &tmp2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
final = append(final, tmp2)
|
||||
}
|
||||
|
||||
return final, err
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
150
internal/tools/mongodb/mongodbfind/mongodbfind_test.go
Normal file
150
internal/tools/mongodb/mongodbfind/mongodbfind_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodbfind_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfind"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-find
|
||||
source: my-instance
|
||||
description: some description
|
||||
database: test_db
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name: {{json .name}} }
|
||||
filterParams:
|
||||
- name: name
|
||||
type: string
|
||||
description: small description
|
||||
projectPayload: |
|
||||
{ name: 1, age: 1 }
|
||||
projectParams: []
|
||||
sortPayload: |
|
||||
{ timestamp: -1 }
|
||||
sortParams: []
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": mongodbfind.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "mongodb-find",
|
||||
Source: "my-instance",
|
||||
AuthRequired: []string{},
|
||||
Database: "test_db",
|
||||
Collection: "test_coll",
|
||||
Description: "some description",
|
||||
FilterPayload: "{ name: {{json .name}} }\n",
|
||||
FilterParams: tools.Parameters{
|
||||
&tools.StringParameter{
|
||||
CommonParameter: tools.CommonParameter{
|
||||
Name: "name",
|
||||
Type: "string",
|
||||
Desc: "small description",
|
||||
},
|
||||
},
|
||||
},
|
||||
ProjectPayload: "{ name: 1, age: 1 }\n",
|
||||
ProjectParams: tools.Parameters{},
|
||||
SortPayload: "{ timestamp: -1 }\n",
|
||||
SortParams: tools.Parameters{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "Invalid method",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-find
|
||||
source: my-instance
|
||||
description: some description
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name : {{json .name}} }
|
||||
`,
|
||||
err: `unable to parse tool "example_tool" as kind "mongodb-find"`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
234
internal/tools/mongodb/mongodbfindone/mongodbfindone.go
Normal file
234
internal/tools/mongodb/mongodbfindone/mongodbfindone.go
Normal file
@@ -0,0 +1,234 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package mongodbfindone
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "mongodb-find-one"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
Collection string `yaml:"collection" validate:"required"`
|
||||
FilterPayload string `yaml:"filterPayload" validate:"required"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams" validate:"required"`
|
||||
ProjectPayload string `yaml:"projectPayload"`
|
||||
ProjectParams tools.Parameters `yaml:"projectParams"`
|
||||
SortPayload string `yaml:"sortPayload"`
|
||||
SortParams tools.Parameters `yaml:"sortParams"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*mongosrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind)
|
||||
}
|
||||
|
||||
// Create a slice for all parameters
|
||||
allParameters := slices.Concat(cfg.FilterParams, cfg.ProjectParams, cfg.SortParams)
|
||||
|
||||
// Verify no duplicate parameter names
|
||||
err := tools.CheckDuplicateParameters(allParameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create Toolbox manifest
|
||||
paramManifest := allParameters.Manifest()
|
||||
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
// Create MCP manifest
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: allParameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Collection: cfg.Collection,
|
||||
FilterPayload: cfg.FilterPayload,
|
||||
FilterParams: cfg.FilterParams,
|
||||
ProjectPayload: cfg.ProjectPayload,
|
||||
ProjectParams: cfg.ProjectParams,
|
||||
SortPayload: cfg.SortPayload,
|
||||
SortParams: cfg.SortParams,
|
||||
AllParams: allParameters,
|
||||
database: s.Client.Database(cfg.Database),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Description string `yaml:"description"`
|
||||
Collection string `yaml:"collection"`
|
||||
FilterPayload string `yaml:"filterPayload"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams"`
|
||||
ProjectPayload string `yaml:"projectPayload"`
|
||||
ProjectParams tools.Parameters `yaml:"projectParams"`
|
||||
SortPayload string `yaml:"sortPayload"`
|
||||
SortParams tools.Parameters `yaml:"sortParams"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
database *mongo.Database
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func getOptions(sortParameters tools.Parameters, projectPayload string, paramsMap map[string]any) (*options.FindOneOptions, error) {
|
||||
opts := options.FindOne()
|
||||
|
||||
sort := bson.M{}
|
||||
for _, p := range sortParameters {
|
||||
sort[p.GetName()] = paramsMap[p.GetName()]
|
||||
}
|
||||
opts = opts.SetSort(sort)
|
||||
|
||||
if len(projectPayload) == 0 {
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
result, err := tools.PopulateTemplateWithJSON("MongoDBFindOneProjectString", projectPayload, paramsMap)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating project payload: %s", err)
|
||||
}
|
||||
|
||||
var projection any
|
||||
err = bson.Unmarshal([]byte(result), &projection)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling projection: %s", err)
|
||||
}
|
||||
opts = opts.SetProjection(projection)
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
filterString, err := tools.PopulateTemplateWithJSON("MongoDBFindOneFilterString", t.FilterPayload, paramsMap)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating filter: %s", err)
|
||||
}
|
||||
|
||||
opts, err := getOptions(t.SortParams, t.ProjectPayload, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating options: %s", err)
|
||||
}
|
||||
|
||||
var filter = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := t.database.Collection(t.Collection).FindOne(ctx, filter, opts)
|
||||
if res.Err() != nil {
|
||||
return nil, res.Err()
|
||||
}
|
||||
|
||||
var data any
|
||||
err = res.Decode(&data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var final []any
|
||||
tmp, _ := bson.MarshalExtJSON(data, false, false)
|
||||
var tmp2 any
|
||||
err = json.Unmarshal(tmp, &tmp2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
final = append(final, tmp2)
|
||||
|
||||
return final, err
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
150
internal/tools/mongodb/mongodbfindone/mongodbfindone_test.go
Normal file
150
internal/tools/mongodb/mongodbfindone/mongodbfindone_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodbfindone_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfindone"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-find-one
|
||||
source: my-instance
|
||||
description: some description
|
||||
database: test_db
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name: {{json .name}} }
|
||||
filterParams:
|
||||
- name: name
|
||||
type: string
|
||||
description: small description
|
||||
projectPayload: |
|
||||
{ name: 1, age: 1 }
|
||||
projectParams: []
|
||||
sortPayload: |
|
||||
{ timestamp: -1 }
|
||||
sortParams: []
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": mongodbfindone.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "mongodb-find-one",
|
||||
Source: "my-instance",
|
||||
AuthRequired: []string{},
|
||||
Database: "test_db",
|
||||
Collection: "test_coll",
|
||||
Description: "some description",
|
||||
FilterPayload: "{ name: {{json .name}} }\n",
|
||||
FilterParams: tools.Parameters{
|
||||
&tools.StringParameter{
|
||||
CommonParameter: tools.CommonParameter{
|
||||
Name: "name",
|
||||
Type: "string",
|
||||
Desc: "small description",
|
||||
},
|
||||
},
|
||||
},
|
||||
ProjectPayload: "{ name: 1, age: 1 }\n",
|
||||
ProjectParams: tools.Parameters{},
|
||||
SortPayload: "{ timestamp: -1 }\n",
|
||||
SortParams: tools.Parameters{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "Invalid method",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-find-one
|
||||
source: my-instance
|
||||
description: some description
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name : {{json .name}} }
|
||||
`,
|
||||
err: `unable to parse tool "example_tool" as kind "mongodb-find-one"`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user