mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 16:38:15 -05:00
Compare commits
9 Commits
ci-js-test
...
mongodb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ff819386a | ||
|
|
6119d401f2 | ||
|
|
967cd98cb0 | ||
|
|
2a94a33d63 | ||
|
|
f93693c92d | ||
|
|
9903df0715 | ||
|
|
ed6d6b8e4a | ||
|
|
b261be23a1 | ||
|
|
61b08a345e |
@@ -425,7 +425,7 @@ steps:
|
||||
"Valkey" \
|
||||
valkey \
|
||||
valkey
|
||||
|
||||
|
||||
- id: "firestore"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
@@ -466,8 +466,26 @@ steps:
|
||||
"Looker" \
|
||||
looker \
|
||||
looker
|
||||
|
||||
|
||||
- id: "mongodb"
|
||||
name : golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
- "MONGODB_DATABASE=$_MONGODB_DATABASE"
|
||||
secretEnv: ["CLIENT_ID", "MONGODB_URI"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
.ci/test_with_coverage.sh \
|
||||
"MongoDB" \
|
||||
mongodb \
|
||||
mongodb
|
||||
|
||||
- id: "alloydbwaitforoperation"
|
||||
name: golang:1
|
||||
@@ -546,6 +564,8 @@ availableSecrets:
|
||||
env: LOOKER_CLIENT_ID
|
||||
- versionName: projects/107716898620/secrets/looker_client_secret/versions/latest
|
||||
env: LOOKER_CLIENT_SECRET
|
||||
- versionName: projects/$PROJECT_ID/secrets/mongodb_uri/versions/latest
|
||||
env: MONGODB_URI
|
||||
|
||||
|
||||
options:
|
||||
@@ -579,3 +599,4 @@ substitutions:
|
||||
_COUCHBASE_BUCKET: "couchbase-bucket"
|
||||
_COUCHBASE_SCOPE: "couchbase-scope"
|
||||
_LOOKER_VERIFY_SSL: "true"
|
||||
_MONGODB_DATABASE: "test"
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,6 +4,9 @@
|
||||
# vscode
|
||||
.vscode/
|
||||
|
||||
# idea
|
||||
.idea/
|
||||
|
||||
# npm
|
||||
node_modules
|
||||
|
||||
|
||||
10
cmd/root.go
10
cmd/root.go
@@ -69,6 +69,15 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquery"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquerysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrunlook"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbaggregate"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeletemany"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeleteone"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfind"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfindone"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertmany"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertone"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdatemany"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdateone"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql"
|
||||
@@ -98,6 +107,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/mssql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/mysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||
|
||||
34
docs/en/resources/sources/mongodb.md
Normal file
34
docs/en/resources/sources/mongodb.md
Normal file
@@ -0,0 +1,34 @@
|
||||
---
|
||||
title: "MongoDB"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
MongoDB is a no-sql data platform that can not only serve general purpose data requirements also perform VectorSearch where both operational data and embeddings used of search can reside in same document.
|
||||
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
[MongoDB][mongodb-docs] is a leading nosql database that can not only cater your operational data needs but also perform vector search.
|
||||
|
||||
[mongodb-docs]: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-mongodb:
|
||||
kind: mongodb
|
||||
uri: "mongodb+srv://username:password@host.mongodb.net"
|
||||
database: sample_mflix
|
||||
|
||||
```
|
||||
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-----------|:--------:|:------------:|-------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "mongodb". |
|
||||
| uri | string | true | connection string to connect to MongoDB | |
|
||||
| database | string | true | Name of the mongodb database to connect to (e.g. "sample_mflix"). |
|
||||
5
go.mod
5
go.mod
@@ -33,6 +33,7 @@ require (
|
||||
github.com/spf13/cobra v1.9.1
|
||||
github.com/thlib/go-timezone-local v0.0.7
|
||||
github.com/valkey-io/valkey-go v1.0.63
|
||||
go.mongodb.org/mongo-driver v1.17.4
|
||||
go.opentelemetry.io/contrib/propagators/autoprop v0.62.0
|
||||
go.opentelemetry.io/otel v1.37.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.37.0
|
||||
@@ -106,12 +107,16 @@ require (
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/montanaflynn/stats v0.7.1 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.18 // indirect
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/spf13/pflag v1.0.6 // indirect
|
||||
github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||
github.com/xdg-go/scram v1.1.2 // indirect
|
||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||
github.com/zeebo/errs v1.4.0 // indirect
|
||||
github.com/zeebo/xxh3 v1.0.2 // indirect
|
||||
|
||||
10
go.sum
10
go.sum
@@ -1029,6 +1029,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
|
||||
github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/neo4j/neo4j-go-driver/v5 v5.28.1 h1:RKWQW7wTgYAY2fU9S+9LaJ9OwRPbRc0I17tlT7nDmAY=
|
||||
@@ -1102,6 +1104,12 @@ github.com/thlib/go-timezone-local v0.0.7 h1:fX8zd3aJydqLlTs/TrROrIIdztzsdFV23Oz
|
||||
github.com/thlib/go-timezone-local v0.0.7/go.mod h1:/Tnicc6m/lsJE0irFMA0LfIwTBo4QP7A8IfyIv4zZKI=
|
||||
github.com/valkey-io/valkey-go v1.0.63 h1:LNlDTcUxy9jxrmGHSvd0s/NsgEmQbvREYvvBAHCIir0=
|
||||
github.com/valkey-io/valkey-go v1.0.63/go.mod h1:bHmwjIEOrGq/ubOJfh5uMRs7Xj6mV3mQ/ZXUbmqpjqY=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
|
||||
github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
|
||||
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
|
||||
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@@ -1117,6 +1125,8 @@ github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM=
|
||||
github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4=
|
||||
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
|
||||
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
|
||||
go.mongodb.org/mongo-driver v1.17.4 h1:jUorfmVzljjr0FLzYQsGP8cgN/qzzxlY9Vh0C9KFXVw=
|
||||
go.mongodb.org/mongo-driver v1.17.4/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
|
||||
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
|
||||
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
|
||||
106
internal/sources/mongodb/mongodb.go
Normal file
106
internal/sources/mongodb/mongodb.go
Normal file
@@ -0,0 +1,106 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "mongodb"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Uri string `yaml:"uri" validate:"required"` // MongoDB Atlas connection URI
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
client, err := initMongoDBClient(ctx, tracer, r.Name, r.Uri)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create MongoDB client: %w", err)
|
||||
}
|
||||
|
||||
// Verify the connection
|
||||
err = client.Ping(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to connect successfully: %w", err)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
Client: client,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Client *mongo.Client
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) MongoClient() *mongo.Client {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func initMongoDBClient(ctx context.Context, tracer trace.Tracer, name, uri string) (*mongo.Client, error) {
|
||||
// Start a tracing span
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Create a new MongoDB client
|
||||
clientOpts := options.Client().ApplyURI(uri)
|
||||
client, err := mongo.Connect(ctx, clientOpts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create MongoDB client: %w", err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
111
internal/sources/mongodb/mongodb_test.go
Normal file
111
internal/sources/mongodb/mongodb_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodb_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMongoDB(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
mongo-db:
|
||||
kind: "mongodb"
|
||||
uri: "mongodb+srv://username:password@host/dbname"
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"mongo-db": mongodb.Config{
|
||||
Name: "mongo-db",
|
||||
Kind: mongodb.SourceKind,
|
||||
Uri: "mongodb+srv://username:password@host/dbname",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
sources:
|
||||
mongo-db:
|
||||
kind: mongodb
|
||||
uri: "mongodb+srv://username:password@host/dbname"
|
||||
foo: bar
|
||||
`,
|
||||
err: "unable to parse source \"mongo-db\" as \"mongodb\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: mongodb\n 3 | uri: mongodb+srv://username:password@host/dbname",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
sources:
|
||||
mongo-db:
|
||||
kind: mongodb
|
||||
`,
|
||||
err: "unable to parse source \"mongo-db\" as \"mongodb\": Key: 'Config.Uri' Error:Field validation for 'Uri' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got \n%q, want \n%q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
75
internal/tools/mongodb/common/common.go
Normal file
75
internal/tools/mongodb/common/common.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"text/template"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
// helper function to convert a parameter to JSON formatted string.
|
||||
func ConvertParamToJSON(param any) (string, error) {
|
||||
jsonData, err := json.Marshal(param)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal param to JSON: %w", err)
|
||||
}
|
||||
return string(jsonData), nil
|
||||
}
|
||||
|
||||
func ParsePayloadTemplate(params tools.Parameters, payload string, paramsMap map[string]any) (string, error) {
|
||||
// Create a map for request body parameters
|
||||
cleanParamsMap := make(map[string]any)
|
||||
for _, p := range params {
|
||||
k := p.GetName()
|
||||
v, ok := paramsMap[k]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing parameter %s", k)
|
||||
}
|
||||
cleanParamsMap[k] = v
|
||||
}
|
||||
|
||||
// Create a FuncMap to format array parameters
|
||||
funcMap := template.FuncMap{
|
||||
"json": ConvertParamToJSON,
|
||||
}
|
||||
templ, err := template.New("template").Funcs(funcMap).Parse(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing: %s", err)
|
||||
}
|
||||
var result bytes.Buffer
|
||||
err = templ.Execute(&result, cleanParamsMap)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error replacing payload: %s", err)
|
||||
}
|
||||
return result.String(), nil
|
||||
}
|
||||
|
||||
func GetUpdate(params tools.Parameters, payload string, paramsMap map[string]any) (string, error) {
|
||||
// Create a map for request body parameters
|
||||
cleanParamsMap := make(map[string]any)
|
||||
for _, p := range params {
|
||||
k := p.GetName()
|
||||
v, ok := paramsMap[k]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing update parameter %s", k)
|
||||
}
|
||||
cleanParamsMap[k] = v
|
||||
}
|
||||
|
||||
// Create a FuncMap to format array parameters
|
||||
funcMap := template.FuncMap{
|
||||
"json": ConvertParamToJSON,
|
||||
}
|
||||
templ, err := template.New("filter").Funcs(funcMap).Parse(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing filter: %s", err)
|
||||
}
|
||||
var result bytes.Buffer
|
||||
err = templ.Execute(&result, cleanParamsMap)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error replacing filter payload: %s", err)
|
||||
}
|
||||
return result.String(), nil
|
||||
}
|
||||
237
internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go
Normal file
237
internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go
Normal file
@@ -0,0 +1,237 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package mongodbaggregate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/common"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "mongodb-aggregate"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
Collection string `yaml:"collection" validate:"required"`
|
||||
PipelinePayload string `yaml:"pipelinePayload" validate:"required"`
|
||||
PipelineParams tools.Parameters `yaml:"pipelineParams" validate:"required"`
|
||||
Canonical bool `yaml:"canonical"`
|
||||
ReadOnly bool `yaml:"readOnly"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*mongosrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind)
|
||||
}
|
||||
|
||||
// Create a slice for all parameters
|
||||
allParameters := slices.Concat(cfg.PipelineParams)
|
||||
|
||||
// Create parameter MCP manifest
|
||||
paramManifest := slices.Concat(
|
||||
cfg.PipelineParams.Manifest(),
|
||||
)
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
pipelineMcpManifest := cfg.PipelineParams.McpManifest()
|
||||
|
||||
// Concatenate parameters for MCP `required` field
|
||||
concatRequiredManifest := slices.Concat(
|
||||
pipelineMcpManifest.Required,
|
||||
)
|
||||
if concatRequiredManifest == nil {
|
||||
concatRequiredManifest = []string{}
|
||||
}
|
||||
|
||||
// Concatenate parameters for MCP `properties` field
|
||||
concatPropertiesManifest := make(map[string]tools.ParameterMcpManifest)
|
||||
for name, p := range pipelineMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
|
||||
// Create a new McpToolsSchema with all parameters
|
||||
paramMcpManifest := tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: concatPropertiesManifest,
|
||||
Required: concatRequiredManifest,
|
||||
}
|
||||
|
||||
// Verify there are no duplicate parameter names
|
||||
seenNames := make(map[string]bool)
|
||||
for _, param := range paramManifest {
|
||||
if _, exists := seenNames[param.Name]; exists {
|
||||
return nil, fmt.Errorf("parameter name must be unique across pipelineParams, and sortParams. Duplicate parameter: %s", param.Name)
|
||||
}
|
||||
seenNames[param.Name] = true
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Collection: cfg.Collection,
|
||||
PipelinePayload: cfg.PipelinePayload,
|
||||
PipelineParams: cfg.PipelineParams,
|
||||
Canonical: cfg.Canonical,
|
||||
ReadOnly: cfg.ReadOnly,
|
||||
AllParams: allParameters,
|
||||
database: s.Client.Database(cfg.Database),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Collection string `yaml:"collection"`
|
||||
PipelinePayload string `yaml:"pipelinePayload"`
|
||||
PipelineParams tools.Parameters `yaml:"pipelineParams"`
|
||||
Canonical bool `yaml:"canonical"`
|
||||
ReadOnly bool `yaml:"readOnly"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
database *mongo.Database
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
pipelineString, err := common.ParsePayloadTemplate(t.PipelineParams, t.PipelinePayload, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating pipeline: %s", err)
|
||||
}
|
||||
|
||||
var pipeline = []bson.M{}
|
||||
err = bson.UnmarshalExtJSON([]byte(pipelineString), t.Canonical, &pipeline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if t.ReadOnly {
|
||||
//fail if we do a merge or an out
|
||||
for _, stage := range pipeline {
|
||||
for key := range stage {
|
||||
if key == "$merge" || key == "$out" {
|
||||
return nil, fmt.Errorf("this is not a read-only pipeline: %+v", stage)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cur, err := t.database.Collection(t.Collection).Aggregate(ctx, pipeline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cur.Close(ctx)
|
||||
|
||||
var data = []any{}
|
||||
err = cur.All(ctx, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return []any{}, nil
|
||||
}
|
||||
|
||||
var final []any
|
||||
for _, item := range data {
|
||||
tmp, _ := bson.MarshalExtJSON(item, false, false)
|
||||
var tmp2 any
|
||||
err = json.Unmarshal(tmp, &tmp2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
final = append(final, tmp2)
|
||||
}
|
||||
|
||||
return final, err
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
141
internal/tools/mongodb/mongodbaggregate/mongodbaggregate_test.go
Normal file
141
internal/tools/mongodb/mongodbaggregate/mongodbaggregate_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodbaggregate_test
|
||||
|
||||
import (
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbaggregate"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-aggregate
|
||||
source: my-instance
|
||||
description: some description
|
||||
database: test_db
|
||||
collection: test_coll
|
||||
readOnly: true
|
||||
pipelinePayload: |
|
||||
[{ $match: { name: {{json .name}} }}]
|
||||
pipelineParams:
|
||||
- name: name
|
||||
type: string
|
||||
description: small description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": mongodbaggregate.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "mongodb-aggregate",
|
||||
Source: "my-instance",
|
||||
AuthRequired: []string{},
|
||||
Database: "test_db",
|
||||
Collection: "test_coll",
|
||||
Description: "some description",
|
||||
PipelinePayload: "[{ $match: { name: {{json .name}} }}]\n",
|
||||
PipelineParams: tools.Parameters{
|
||||
&tools.StringParameter{
|
||||
CommonParameter: tools.CommonParameter{
|
||||
Name: "name",
|
||||
Type: "string",
|
||||
Desc: "small description",
|
||||
},
|
||||
},
|
||||
},
|
||||
ReadOnly: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "Invalid method",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-aggregate
|
||||
source: my-instance
|
||||
description: some description
|
||||
collection: test_coll
|
||||
pipelinePayload: |
|
||||
[{ $match: { name : {{json .name}} }}]
|
||||
`,
|
||||
err: `unable to parse tool "example_tool" as kind "mongodb-aggregate"`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
207
internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go
Normal file
207
internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go
Normal file
@@ -0,0 +1,207 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package mongodbdeletemany
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/common"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "mongodb-delete-many"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
Collection string `yaml:"collection" validate:"required"`
|
||||
FilterPayload string `yaml:"filterPayload" validate:"required"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams" validate:"required"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*mongosrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind)
|
||||
}
|
||||
|
||||
// Create a slice for all parameters
|
||||
allParameters := slices.Concat(cfg.FilterParams)
|
||||
|
||||
// Create parameter MCP manifest
|
||||
paramManifest := slices.Concat(
|
||||
cfg.FilterParams.Manifest(),
|
||||
)
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
filterMcpManifest := cfg.FilterParams.McpManifest()
|
||||
|
||||
// Concatenate parameters for MCP `required` field
|
||||
concatRequiredManifest := slices.Concat(
|
||||
filterMcpManifest.Required,
|
||||
)
|
||||
|
||||
if concatRequiredManifest == nil {
|
||||
concatRequiredManifest = []string{}
|
||||
}
|
||||
|
||||
// Concatenate parameters for MCP `properties` field
|
||||
concatPropertiesManifest := make(map[string]tools.ParameterMcpManifest)
|
||||
for name, p := range filterMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
|
||||
// Create a new McpToolsSchema with all parameters
|
||||
paramMcpManifest := tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: concatPropertiesManifest,
|
||||
Required: concatRequiredManifest,
|
||||
}
|
||||
|
||||
// Verify there are no duplicate parameter names
|
||||
seenNames := make(map[string]bool)
|
||||
for _, param := range paramManifest {
|
||||
if _, exists := seenNames[param.Name]; exists {
|
||||
return nil, fmt.Errorf("parameter name must be unique across filterParams, projectParams, and sortParams. Duplicate parameter: %s", param.Name)
|
||||
}
|
||||
seenNames[param.Name] = true
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Collection: cfg.Collection,
|
||||
FilterPayload: cfg.FilterPayload,
|
||||
FilterParams: cfg.FilterParams,
|
||||
AllParams: allParameters,
|
||||
database: s.Client.Database(cfg.Database),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Description string `yaml:"description"`
|
||||
Collection string `yaml:"collection"`
|
||||
FilterPayload string `yaml:"filterPayload"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
database *mongo.Database
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
filterString, err := common.ParsePayloadTemplate(t.FilterParams, t.FilterPayload, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating filter: %s", err)
|
||||
}
|
||||
|
||||
opts := options.Delete()
|
||||
|
||||
var filter = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := t.database.Collection(t.Collection).DeleteMany(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.DeletedCount == 0 {
|
||||
return nil, errors.New("no document found")
|
||||
}
|
||||
|
||||
// not much to return actually
|
||||
return res.DeletedCount, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodbdeletemany_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeletemany"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-delete-many
|
||||
source: my-instance
|
||||
description: some description
|
||||
database: test_db
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name: {{json .name}} }
|
||||
filterParams:
|
||||
- name: name
|
||||
type: string
|
||||
description: small description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": mongodbdeletemany.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "mongodb-delete-many",
|
||||
Source: "my-instance",
|
||||
AuthRequired: []string{},
|
||||
Database: "test_db",
|
||||
Collection: "test_coll",
|
||||
Description: "some description",
|
||||
FilterPayload: "{ name: {{json .name}} }\n",
|
||||
FilterParams: tools.Parameters{
|
||||
&tools.StringParameter{
|
||||
CommonParameter: tools.CommonParameter{
|
||||
Name: "name",
|
||||
Type: "string",
|
||||
Desc: "small description",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "Invalid method",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-delete-many
|
||||
source: my-instance
|
||||
description: some description
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name : {{json .name}} }
|
||||
`,
|
||||
err: `unable to parse tool "example_tool" as kind "mongodb-delete-many"`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
202
internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go
Normal file
202
internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go
Normal file
@@ -0,0 +1,202 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package mongodbdeleteone
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/common"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "mongodb-delete-one"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
Collection string `yaml:"collection" validate:"required"`
|
||||
FilterPayload string `yaml:"filterPayload" validate:"required"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams" validate:"required"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*mongosrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind)
|
||||
}
|
||||
|
||||
// Create a slice for all parameters
|
||||
allParameters := slices.Concat(cfg.FilterParams)
|
||||
|
||||
// Create parameter MCP manifest
|
||||
paramManifest := slices.Concat(
|
||||
cfg.FilterParams.Manifest(),
|
||||
)
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
filterMcpManifest := cfg.FilterParams.McpManifest()
|
||||
|
||||
// Concatenate parameters for MCP `required` field
|
||||
concatRequiredManifest := slices.Concat(
|
||||
filterMcpManifest.Required,
|
||||
)
|
||||
|
||||
if concatRequiredManifest == nil {
|
||||
concatRequiredManifest = []string{}
|
||||
}
|
||||
|
||||
// Concatenate parameters for MCP `properties` field
|
||||
concatPropertiesManifest := make(map[string]tools.ParameterMcpManifest)
|
||||
for name, p := range filterMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
|
||||
// Create a new McpToolsSchema with all parameters
|
||||
paramMcpManifest := tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: concatPropertiesManifest,
|
||||
Required: concatRequiredManifest,
|
||||
}
|
||||
|
||||
// Verify there are no duplicate parameter names
|
||||
seenNames := make(map[string]bool)
|
||||
for _, param := range paramManifest {
|
||||
if _, exists := seenNames[param.Name]; exists {
|
||||
return nil, fmt.Errorf("parameter name must be unique across filterParams, projectParams, and sortParams. Duplicate parameter: %s", param.Name)
|
||||
}
|
||||
seenNames[param.Name] = true
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Collection: cfg.Collection,
|
||||
FilterPayload: cfg.FilterPayload,
|
||||
FilterParams: cfg.FilterParams,
|
||||
AllParams: allParameters,
|
||||
database: s.Client.Database(cfg.Database),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Description string `yaml:"description"`
|
||||
Collection string `yaml:"collection"`
|
||||
FilterPayload string `yaml:"filterPayload"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
database *mongo.Database
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
filterString, err := common.ParsePayloadTemplate(t.FilterParams, t.FilterPayload, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating filter: %s", err)
|
||||
}
|
||||
|
||||
opts := options.Delete()
|
||||
|
||||
var filter = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := t.database.Collection(t.Collection).DeleteOne(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// do not return an error when the count is 0, to mirror the delete many call result
|
||||
return res.DeletedCount, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
140
internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone_test.go
Normal file
140
internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodbdeleteone_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeleteone"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-delete-one
|
||||
source: my-instance
|
||||
description: some description
|
||||
database: test_db
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name: {{json .name}} }
|
||||
filterParams:
|
||||
- name: name
|
||||
type: string
|
||||
description: small description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": mongodbdeleteone.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "mongodb-delete-one",
|
||||
Source: "my-instance",
|
||||
AuthRequired: []string{},
|
||||
Database: "test_db",
|
||||
Collection: "test_coll",
|
||||
Description: "some description",
|
||||
FilterPayload: "{ name: {{json .name}} }\n",
|
||||
FilterParams: tools.Parameters{
|
||||
&tools.StringParameter{
|
||||
CommonParameter: tools.CommonParameter{
|
||||
Name: "name",
|
||||
Type: "string",
|
||||
Desc: "small description",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "Invalid method",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-delete-one
|
||||
source: my-instance
|
||||
description: some description
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name : {{json .name}} }
|
||||
`,
|
||||
err: `unable to parse tool "example_tool" as kind "mongodb-delete-one"`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
300
internal/tools/mongodb/mongodbfind/mongodbfind.go
Normal file
300
internal/tools/mongodb/mongodbfind/mongodbfind.go
Normal file
@@ -0,0 +1,300 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package mongodbfind
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"text/template"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/common"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "mongodb-find"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
Collection string `yaml:"collection" validate:"required"`
|
||||
FilterPayload string `yaml:"filterPayload" validate:"required"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams" validate:"required"`
|
||||
ProjectPayload string `yaml:"projectPayload"`
|
||||
ProjectParams tools.Parameters `yaml:"projectParams"`
|
||||
SortPayload string `yaml:"sortPayload"`
|
||||
SortParams tools.Parameters `yaml:"sortParams"`
|
||||
Limit int64 `yaml:"limit"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*mongosrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind)
|
||||
}
|
||||
|
||||
// Create a slice for all parameters
|
||||
allParameters := slices.Concat(cfg.FilterParams, cfg.ProjectParams, cfg.SortParams)
|
||||
|
||||
// Create parameter MCP manifest
|
||||
paramManifest := slices.Concat(
|
||||
cfg.FilterParams.Manifest(),
|
||||
cfg.ProjectParams.Manifest(),
|
||||
cfg.SortParams.Manifest(),
|
||||
)
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
filterMcpManifest := cfg.FilterParams.McpManifest()
|
||||
projectMcpManifest := cfg.ProjectParams.McpManifest()
|
||||
sortMcpManifest := cfg.SortParams.McpManifest()
|
||||
|
||||
// Concatenate parameters for MCP `required` field
|
||||
concatRequiredManifest := slices.Concat(
|
||||
filterMcpManifest.Required,
|
||||
projectMcpManifest.Required,
|
||||
sortMcpManifest.Required,
|
||||
)
|
||||
if concatRequiredManifest == nil {
|
||||
concatRequiredManifest = []string{}
|
||||
}
|
||||
|
||||
// Concatenate parameters for MCP `properties` field
|
||||
concatPropertiesManifest := make(map[string]tools.ParameterMcpManifest)
|
||||
for name, p := range filterMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
for name, p := range projectMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
for name, p := range sortMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
|
||||
// Create a new McpToolsSchema with all parameters
|
||||
paramMcpManifest := tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: concatPropertiesManifest,
|
||||
Required: concatRequiredManifest,
|
||||
}
|
||||
|
||||
// Verify there are no duplicate parameter names
|
||||
seenNames := make(map[string]bool)
|
||||
for _, param := range paramManifest {
|
||||
if _, exists := seenNames[param.Name]; exists {
|
||||
return nil, fmt.Errorf("parameter name must be unique across filterParams, projectParams, and sortParams. Duplicate parameter: %s", param.Name)
|
||||
}
|
||||
seenNames[param.Name] = true
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Collection: cfg.Collection,
|
||||
FilterPayload: cfg.FilterPayload,
|
||||
FilterParams: cfg.FilterParams,
|
||||
ProjectPayload: cfg.ProjectPayload,
|
||||
ProjectParams: cfg.ProjectParams,
|
||||
SortPayload: cfg.SortPayload,
|
||||
SortParams: cfg.SortParams,
|
||||
Limit: cfg.Limit,
|
||||
AllParams: allParameters,
|
||||
database: s.Client.Database(cfg.Database),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Collection string `yaml:"collection"`
|
||||
FilterPayload string `yaml:"filterPayload"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams"`
|
||||
ProjectPayload string `yaml:"projectPayload"`
|
||||
ProjectParams tools.Parameters `yaml:"projectParams"`
|
||||
SortPayload string `yaml:"sortPayload"`
|
||||
SortParams tools.Parameters `yaml:"sortParams"`
|
||||
Limit int64 `yaml:"limit"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
database *mongo.Database
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func getOptions(sortParameters tools.Parameters, projectParams tools.Parameters, projectPayload string, limit int64, paramsMap map[string]any) (*options.FindOptions, error) {
|
||||
opts := options.Find()
|
||||
|
||||
sort := bson.M{}
|
||||
for _, p := range sortParameters {
|
||||
sort[p.GetName()] = paramsMap[p.GetName()]
|
||||
}
|
||||
opts = opts.SetSort(sort)
|
||||
|
||||
if len(projectPayload) == 0 {
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
project := bson.M{}
|
||||
|
||||
for _, p := range projectParams {
|
||||
project[p.GetName()] = paramsMap[p.GetName()]
|
||||
}
|
||||
|
||||
// Create a FuncMap to format array parameters
|
||||
funcMap := template.FuncMap{
|
||||
"json": common.ConvertParamToJSON,
|
||||
}
|
||||
templ, err := template.New("project").Funcs(funcMap).Parse(projectPayload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing project: %s", err)
|
||||
}
|
||||
|
||||
var result bytes.Buffer
|
||||
err = templ.Execute(&result, project)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error replacing projection payload: %s", err)
|
||||
}
|
||||
|
||||
var projection any
|
||||
err = bson.UnmarshalExtJSON(result.Bytes(), false, &projection)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling projection: %s", err)
|
||||
}
|
||||
|
||||
opts = opts.SetProjection(projection)
|
||||
|
||||
if limit > 0 {
|
||||
opts = opts.SetLimit(limit)
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
filterString, err := common.ParsePayloadTemplate(t.FilterParams, t.FilterPayload, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating filter: %s", err)
|
||||
}
|
||||
|
||||
opts, err := getOptions(t.SortParams, t.ProjectParams, t.ProjectPayload, t.Limit, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating options: %s", err)
|
||||
}
|
||||
|
||||
var filter = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cur, err := t.database.Collection(t.Collection).Find(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cur.Close(ctx)
|
||||
|
||||
var data = []any{}
|
||||
err = cur.All(context.TODO(), &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var final []any
|
||||
for _, item := range data {
|
||||
tmp, _ := bson.MarshalExtJSON(item, false, false)
|
||||
var tmp2 any
|
||||
err = json.Unmarshal(tmp, &tmp2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
final = append(final, tmp2)
|
||||
}
|
||||
|
||||
return final, err
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
150
internal/tools/mongodb/mongodbfind/mongodbfind_test.go
Normal file
150
internal/tools/mongodb/mongodbfind/mongodbfind_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodbfind_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfind"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-find
|
||||
source: my-instance
|
||||
description: some description
|
||||
database: test_db
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name: {{json .name}} }
|
||||
filterParams:
|
||||
- name: name
|
||||
type: string
|
||||
description: small description
|
||||
projectPayload: |
|
||||
{ name: 1, age: 1 }
|
||||
projectParams: []
|
||||
sortPayload: |
|
||||
{ timestamp: -1 }
|
||||
sortParams: []
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": mongodbfind.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "mongodb-find",
|
||||
Source: "my-instance",
|
||||
AuthRequired: []string{},
|
||||
Database: "test_db",
|
||||
Collection: "test_coll",
|
||||
Description: "some description",
|
||||
FilterPayload: "{ name: {{json .name}} }\n",
|
||||
FilterParams: tools.Parameters{
|
||||
&tools.StringParameter{
|
||||
CommonParameter: tools.CommonParameter{
|
||||
Name: "name",
|
||||
Type: "string",
|
||||
Desc: "small description",
|
||||
},
|
||||
},
|
||||
},
|
||||
ProjectPayload: "{ name: 1, age: 1 }\n",
|
||||
ProjectParams: tools.Parameters{},
|
||||
SortPayload: "{ timestamp: -1 }\n",
|
||||
SortParams: tools.Parameters{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "Invalid method",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-find
|
||||
source: my-instance
|
||||
description: some description
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name : {{json .name}} }
|
||||
`,
|
||||
err: `unable to parse tool "example_tool" as kind "mongodb-find"`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
288
internal/tools/mongodb/mongodbfindone/mongodbfindone.go
Normal file
288
internal/tools/mongodb/mongodbfindone/mongodbfindone.go
Normal file
@@ -0,0 +1,288 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package mongodbfindone
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"text/template"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/common"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "mongodb-find-one"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
Collection string `yaml:"collection" validate:"required"`
|
||||
FilterPayload string `yaml:"filterPayload" validate:"required"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams" validate:"required"`
|
||||
ProjectPayload string `yaml:"projectPayload"`
|
||||
ProjectParams tools.Parameters `yaml:"projectParams"`
|
||||
SortPayload string `yaml:"sortPayload"`
|
||||
SortParams tools.Parameters `yaml:"sortParams"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*mongosrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind)
|
||||
}
|
||||
|
||||
// Create a slice for all parameters
|
||||
allParameters := slices.Concat(cfg.FilterParams, cfg.ProjectParams, cfg.SortParams)
|
||||
|
||||
// Create parameter MCP manifest
|
||||
paramManifest := slices.Concat(
|
||||
cfg.FilterParams.Manifest(),
|
||||
cfg.ProjectParams.Manifest(),
|
||||
cfg.SortParams.Manifest(),
|
||||
)
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
filterMcpManifest := cfg.FilterParams.McpManifest()
|
||||
projectMcpManifest := cfg.ProjectParams.McpManifest()
|
||||
sortMcpManifest := cfg.SortParams.McpManifest()
|
||||
|
||||
// Concatenate parameters for MCP `required` field
|
||||
concatRequiredManifest := slices.Concat(
|
||||
filterMcpManifest.Required,
|
||||
projectMcpManifest.Required,
|
||||
sortMcpManifest.Required,
|
||||
)
|
||||
if concatRequiredManifest == nil {
|
||||
concatRequiredManifest = []string{}
|
||||
}
|
||||
|
||||
// Concatenate parameters for MCP `properties` field
|
||||
concatPropertiesManifest := make(map[string]tools.ParameterMcpManifest)
|
||||
for name, p := range filterMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
for name, p := range projectMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
for name, p := range sortMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
|
||||
// Create a new McpToolsSchema with all parameters
|
||||
paramMcpManifest := tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: concatPropertiesManifest,
|
||||
Required: concatRequiredManifest,
|
||||
}
|
||||
|
||||
// Verify there are no duplicate parameter names
|
||||
seenNames := make(map[string]bool)
|
||||
for _, param := range paramManifest {
|
||||
if _, exists := seenNames[param.Name]; exists {
|
||||
return nil, fmt.Errorf("parameter name must be unique across filterParams, projectParams, and sortParams. Duplicate parameter: %s", param.Name)
|
||||
}
|
||||
seenNames[param.Name] = true
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Collection: cfg.Collection,
|
||||
FilterPayload: cfg.FilterPayload,
|
||||
FilterParams: cfg.FilterParams,
|
||||
ProjectPayload: cfg.ProjectPayload,
|
||||
ProjectParams: cfg.ProjectParams,
|
||||
SortPayload: cfg.SortPayload,
|
||||
SortParams: cfg.SortParams,
|
||||
AllParams: allParameters,
|
||||
database: s.Client.Database(cfg.Database),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Description string `yaml:"description"`
|
||||
Collection string `yaml:"collection"`
|
||||
FilterPayload string `yaml:"filterPayload"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams"`
|
||||
ProjectPayload string `yaml:"projectPayload"`
|
||||
ProjectParams tools.Parameters `yaml:"projectParams"`
|
||||
SortPayload string `yaml:"sortPayload"`
|
||||
SortParams tools.Parameters `yaml:"sortParams"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
database *mongo.Database
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func getOptions(sortParameters tools.Parameters, projectParams tools.Parameters, projectPayload string, paramsMap map[string]any) (*options.FindOneOptions, error) {
|
||||
opts := options.FindOne()
|
||||
|
||||
sort := bson.M{}
|
||||
for _, p := range sortParameters {
|
||||
sort[p.GetName()] = paramsMap[p.GetName()]
|
||||
}
|
||||
opts = opts.SetSort(sort)
|
||||
|
||||
if len(projectPayload) == 0 {
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
project := bson.M{}
|
||||
for _, p := range projectParams {
|
||||
project[p.GetName()] = paramsMap[p.GetName()]
|
||||
}
|
||||
|
||||
// Create a FuncMap to format array parameters
|
||||
funcMap := template.FuncMap{
|
||||
"json": common.ConvertParamToJSON,
|
||||
}
|
||||
templ, err := template.New("project").Funcs(funcMap).Parse(projectPayload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing project: %s", err)
|
||||
}
|
||||
|
||||
var result bytes.Buffer
|
||||
err = templ.Execute(&result, project)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error replacing project payload: %s", err)
|
||||
}
|
||||
|
||||
var projection any
|
||||
err = bson.Unmarshal(result.Bytes(), &projection)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling projection: %s", err)
|
||||
}
|
||||
opts = opts.SetProjection(projection)
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
filterString, err := common.ParsePayloadTemplate(t.FilterParams, t.FilterPayload, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating filter: %s", err)
|
||||
}
|
||||
|
||||
opts, err := getOptions(t.SortParams, t.ProjectParams, t.ProjectPayload, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating options: %s", err)
|
||||
}
|
||||
|
||||
var filter = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := t.database.Collection(t.Collection).FindOne(ctx, filter, opts)
|
||||
if res.Err() != nil {
|
||||
return nil, res.Err()
|
||||
}
|
||||
|
||||
var data any
|
||||
err = res.Decode(&data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var final []any
|
||||
tmp, _ := bson.MarshalExtJSON(data, false, false)
|
||||
var tmp2 any
|
||||
err = json.Unmarshal(tmp, &tmp2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
final = append(final, tmp2)
|
||||
|
||||
return final, err
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
150
internal/tools/mongodb/mongodbfindone/mongodbfindone_test.go
Normal file
150
internal/tools/mongodb/mongodbfindone/mongodbfindone_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodbfindone_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfindone"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-find-one
|
||||
source: my-instance
|
||||
description: some description
|
||||
database: test_db
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name: {{json .name}} }
|
||||
filterParams:
|
||||
- name: name
|
||||
type: string
|
||||
description: small description
|
||||
projectPayload: |
|
||||
{ name: 1, age: 1 }
|
||||
projectParams: []
|
||||
sortPayload: |
|
||||
{ timestamp: -1 }
|
||||
sortParams: []
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": mongodbfindone.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "mongodb-find-one",
|
||||
Source: "my-instance",
|
||||
AuthRequired: []string{},
|
||||
Database: "test_db",
|
||||
Collection: "test_coll",
|
||||
Description: "some description",
|
||||
FilterPayload: "{ name: {{json .name}} }\n",
|
||||
FilterParams: tools.Parameters{
|
||||
&tools.StringParameter{
|
||||
CommonParameter: tools.CommonParameter{
|
||||
Name: "name",
|
||||
Type: "string",
|
||||
Desc: "small description",
|
||||
},
|
||||
},
|
||||
},
|
||||
ProjectPayload: "{ name: 1, age: 1 }\n",
|
||||
ProjectParams: tools.Parameters{},
|
||||
SortPayload: "{ timestamp: -1 }\n",
|
||||
SortParams: tools.Parameters{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "Invalid method",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-find-one
|
||||
source: my-instance
|
||||
description: some description
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name : {{json .name}} }
|
||||
`,
|
||||
err: `unable to parse tool "example_tool" as kind "mongodb-find-one"`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
183
internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go
Normal file
183
internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go
Normal file
@@ -0,0 +1,183 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package mongodbinsertmany
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
const kind string = "mongodb-insert-many"
|
||||
|
||||
const paramDataKey = "data"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
Collection string `yaml:"collection" validate:"required"`
|
||||
Canonical bool `yaml:"canonical" validate:"required"` //i want to force the user to choose
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*mongosrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind)
|
||||
}
|
||||
|
||||
dataParam := tools.NewStringParameterWithRequired(paramDataKey, "the JSON payload to insert, should be a JSON array of documents", true)
|
||||
parameters := tools.Parameters{dataParam}
|
||||
|
||||
// Create parameter MCP manifest
|
||||
paramManifest := slices.Concat(
|
||||
parameters.Manifest(),
|
||||
)
|
||||
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
payloadMcpManifest := dataParam.McpManifest()
|
||||
|
||||
// Concatenate parameters for MCP `properties` field
|
||||
concatPropertiesManifest := map[string]tools.ParameterMcpManifest{
|
||||
paramDataKey: payloadMcpManifest,
|
||||
}
|
||||
|
||||
// Create a new McpToolsSchema with all parameters
|
||||
paramMcpManifest := tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: concatPropertiesManifest,
|
||||
Required: []string{paramDataKey},
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Collection: cfg.Collection,
|
||||
Canonical: cfg.Canonical,
|
||||
PayloadParams: parameters,
|
||||
database: s.Client.Database(cfg.Database),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Description string `yaml:"description"`
|
||||
Collection string `yaml:"collection"`
|
||||
Canonical bool `yaml:"canonical" validation:"required"` //i want to force the user to choose
|
||||
PayloadParams tools.Parameters
|
||||
|
||||
database *mongo.Database
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
if len(params) == 0 {
|
||||
return nil, errors.New("no input found")
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
var jsonData, ok = paramsMap[paramDataKey].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("no input found")
|
||||
}
|
||||
|
||||
var data = []any{}
|
||||
err := bson.UnmarshalExtJSON([]byte(jsonData), t.Canonical, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := t.database.Collection(t.Collection).InsertMany(ctx, data, options.InsertMany())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.InsertedIDs, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.PayloadParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodbinsertmany_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertmany"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-insert-many
|
||||
source: my-instance
|
||||
description: some description
|
||||
database: test_db
|
||||
collection: test_coll
|
||||
canonical: true
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": mongodbinsertmany.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "mongodb-insert-many",
|
||||
Source: "my-instance",
|
||||
AuthRequired: []string{},
|
||||
Database: "test_db",
|
||||
Collection: "test_coll",
|
||||
Description: "some description",
|
||||
Canonical: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "Invalid method",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-insert-many
|
||||
source: my-instance
|
||||
description: some description
|
||||
collection: test_coll
|
||||
`,
|
||||
err: `unable to parse tool "example_tool" as kind "mongodb-insert-many"`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
181
internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go
Normal file
181
internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go
Normal file
@@ -0,0 +1,181 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package mongodbinsertone
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
const kind string = "mongodb-insert-one"
|
||||
|
||||
const dataParamsKey = "data"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
Collection string `yaml:"collection" validate:"required"`
|
||||
Canonical bool `yaml:"canonical" validate:"required"` //i want to force the user to choose
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*mongosrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind)
|
||||
}
|
||||
|
||||
payloadParams := tools.NewStringParameterWithRequired(dataParamsKey, "the JSON payload to insert, should be a JSON object", true)
|
||||
parameters := tools.Parameters{payloadParams}
|
||||
|
||||
// Create parameter MCP manifest
|
||||
paramManifest := slices.Concat(
|
||||
parameters.Manifest(),
|
||||
)
|
||||
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
payloadMcpManifest := payloadParams.McpManifest()
|
||||
|
||||
// Concatenate parameters for MCP `properties` field
|
||||
concatPropertiesManifest := map[string]tools.ParameterMcpManifest{
|
||||
dataParamsKey: payloadMcpManifest,
|
||||
}
|
||||
|
||||
// Create a new McpToolsSchema with all parameters
|
||||
paramMcpManifest := tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: concatPropertiesManifest,
|
||||
Required: []string{dataParamsKey},
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Collection: cfg.Collection,
|
||||
Canonical: cfg.Canonical,
|
||||
PayloadParams: parameters,
|
||||
database: s.Client.Database(cfg.Database),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Description string `yaml:"description"`
|
||||
Collection string `yaml:"collection"`
|
||||
Canonical bool `yaml:"canonical" validation:"required"`
|
||||
PayloadParams tools.Parameters `yaml:"payloadParams" validate:"required"`
|
||||
|
||||
database *mongo.Database
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
if len(params) == 0 {
|
||||
return nil, errors.New("no input found")
|
||||
}
|
||||
// use the first, assume it's a string
|
||||
var jsonData, ok = params[0].Value.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("no input found")
|
||||
}
|
||||
|
||||
var data any
|
||||
err := bson.UnmarshalExtJSON([]byte(jsonData), t.Canonical, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := t.database.Collection(t.Collection).InsertOne(ctx, data, options.InsertOne())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.InsertedID, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.PayloadParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
124
internal/tools/mongodb/mongodbinsertone/mongodbinsertone_test.go
Normal file
124
internal/tools/mongodb/mongodbinsertone/mongodbinsertone_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodbinsertone_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertone"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-insert-one
|
||||
source: my-instance
|
||||
description: some description
|
||||
database: test_db
|
||||
collection: test_coll
|
||||
canonical: true
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": mongodbinsertone.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "mongodb-insert-one",
|
||||
Source: "my-instance",
|
||||
AuthRequired: []string{},
|
||||
Database: "test_db",
|
||||
Collection: "test_coll",
|
||||
Canonical: true,
|
||||
Description: "some description",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "Invalid method",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-insert-one
|
||||
source: my-instance
|
||||
description: some description
|
||||
collection: test_coll
|
||||
canonical: true
|
||||
`,
|
||||
err: `unable to parse tool "example_tool" as kind "mongodb-insert-one"`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
225
internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go
Normal file
225
internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go
Normal file
@@ -0,0 +1,225 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package mongodbupdatemany
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/common"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
const kind string = "mongodb-update-many"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
Collection string `yaml:"collection" validate:"required"`
|
||||
FilterPayload string `yaml:"filterPayload" validate:"required"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams" validate:"required"`
|
||||
UpdatePayload string `yaml:"updatePayload" validate:"required"`
|
||||
UpdateParams tools.Parameters `yaml:"updateParams" validate:"required"`
|
||||
Canonical bool `yaml:"canonical" validate:"required"` //i want to force the user to choose
|
||||
Upsert bool `yaml:"upsert"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*mongosrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind)
|
||||
}
|
||||
|
||||
// Create a slice for all parameters
|
||||
allParameters := slices.Concat(cfg.FilterParams, cfg.FilterParams, cfg.UpdateParams)
|
||||
|
||||
// Create parameter MCP manifest
|
||||
paramManifest := slices.Concat(
|
||||
cfg.FilterParams.Manifest(),
|
||||
)
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
filterMcpManifest := cfg.FilterParams.McpManifest()
|
||||
updateMcpManifest := cfg.UpdateParams.McpManifest()
|
||||
|
||||
// Concatenate parameters for MCP `required` field
|
||||
concatRequiredManifest := slices.Concat(
|
||||
filterMcpManifest.Required,
|
||||
updateMcpManifest.Required,
|
||||
)
|
||||
if concatRequiredManifest == nil {
|
||||
concatRequiredManifest = []string{}
|
||||
}
|
||||
|
||||
// Concatenate parameters for MCP `properties` field
|
||||
concatPropertiesManifest := make(map[string]tools.ParameterMcpManifest)
|
||||
for name, p := range filterMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
for name, p := range updateMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
|
||||
// Verify there are no duplicate parameter names
|
||||
seenNames := make(map[string]bool)
|
||||
for _, param := range paramManifest {
|
||||
if _, exists := seenNames[param.Name]; exists {
|
||||
return nil, fmt.Errorf("parameter name must be unique across filterParams, projectParams, and sortParams. Duplicate parameter: %s", param.Name)
|
||||
}
|
||||
seenNames[param.Name] = true
|
||||
}
|
||||
|
||||
// Create a new McpToolsSchema with all parameters
|
||||
paramMcpManifest := tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: concatPropertiesManifest,
|
||||
Required: concatRequiredManifest,
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Collection: cfg.Collection,
|
||||
Canonical: cfg.Canonical,
|
||||
Upsert: cfg.Upsert,
|
||||
FilterPayload: cfg.FilterPayload,
|
||||
FilterParams: cfg.FilterParams,
|
||||
UpdatePayload: cfg.UpdatePayload,
|
||||
UpdateParams: cfg.UpdateParams,
|
||||
AllParams: allParameters,
|
||||
database: s.Client.Database(cfg.Database),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Description string `yaml:"description"`
|
||||
Collection string `yaml:"collection"`
|
||||
FilterPayload string `yaml:"filterPayload" validate:"required"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams" validate:"required"`
|
||||
UpdatePayload string `yaml:"updatePayload" validate:"required"`
|
||||
UpdateParams tools.Parameters `yaml:"updateParams" validate:"required"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
Canonical bool `yaml:"canonical" validation:"required"` //i want to force the user to choose
|
||||
Upsert bool `yaml:"upsert"`
|
||||
|
||||
database *mongo.Database
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
filterString, err := common.ParsePayloadTemplate(t.FilterParams, t.FilterPayload, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating filter: %s", err)
|
||||
}
|
||||
|
||||
var filter = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(filterString), t.Canonical, &filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
|
||||
}
|
||||
|
||||
updateString, err := common.GetUpdate(t.UpdateParams, t.UpdatePayload, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get update: %w", err)
|
||||
}
|
||||
|
||||
var update = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(updateString), false, &update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
|
||||
}
|
||||
|
||||
res, err := t.database.Collection(t.Collection).UpdateMany(ctx, filter, update, options.Update().SetUpsert(t.Upsert))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error updating collection: %w", err)
|
||||
}
|
||||
|
||||
return []any{res.ModifiedCount, res.UpsertedCount, res.MatchedCount}, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
@@ -0,0 +1,169 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodbupdatemany_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdatemany"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-update-many
|
||||
source: my-instance
|
||||
description: some description
|
||||
database: test_db
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name: {{json .name}} }
|
||||
filterParams:
|
||||
- name: name
|
||||
type: string
|
||||
description: small description
|
||||
canonical: true
|
||||
updatePayload: |
|
||||
{ $set: { name: {{json .name}} } }
|
||||
updateParams:
|
||||
- name: name
|
||||
type: string
|
||||
description: small description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": mongodbupdatemany.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "mongodb-update-many",
|
||||
Source: "my-instance",
|
||||
AuthRequired: []string{},
|
||||
Database: "test_db",
|
||||
Collection: "test_coll",
|
||||
FilterPayload: "{ name: {{json .name}} }\n",
|
||||
FilterParams: tools.Parameters{
|
||||
&tools.StringParameter{
|
||||
CommonParameter: tools.CommonParameter{
|
||||
Name: "name",
|
||||
Type: "string",
|
||||
Desc: "small description",
|
||||
},
|
||||
},
|
||||
},
|
||||
UpdatePayload: "{ $set: { name: {{json .name}} } }\n",
|
||||
UpdateParams: tools.Parameters{
|
||||
&tools.StringParameter{
|
||||
CommonParameter: tools.CommonParameter{
|
||||
Name: "name",
|
||||
Type: "string",
|
||||
Desc: "small description",
|
||||
},
|
||||
},
|
||||
},
|
||||
Description: "some description",
|
||||
Canonical: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "Invalid method",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-update-many
|
||||
source: my-instance
|
||||
description: some description
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name : {{json .name}} }
|
||||
filterParams:
|
||||
- name: name
|
||||
type: string
|
||||
description: small description
|
||||
canonical: true
|
||||
updatePayload: |
|
||||
{ $set: { name: {{json .name}} } }
|
||||
updateParams:
|
||||
- name: data
|
||||
type: string
|
||||
description: the content in json
|
||||
`,
|
||||
err: `unable to parse tool "example_tool" as kind "mongodb-update-many"`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
221
internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go
Normal file
221
internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go
Normal file
@@ -0,0 +1,221 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package mongodbupdateone
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/common"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
const kind string = "mongodb-update-one"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
Collection string `yaml:"collection" validate:"required"`
|
||||
FilterPayload string `yaml:"filterPayload" validate:"required"`
|
||||
FilterParams tools.Parameters `yaml:"filterParams" validate:"required"`
|
||||
UpdatePayload string `yaml:"updatePayload" validate:"required"`
|
||||
UpdateParams tools.Parameters `yaml:"updateParams" validate:"required"`
|
||||
|
||||
Canonical bool `yaml:"canonical" validate:"required"` //i want to force the user to choose
|
||||
Upsert bool `yaml:"upsert"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*mongosrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind)
|
||||
}
|
||||
|
||||
// Create a slice for all parameters
|
||||
allParameters := slices.Concat(cfg.FilterParams, cfg.FilterParams, cfg.UpdateParams)
|
||||
|
||||
// Create parameter MCP manifest
|
||||
paramManifest := slices.Concat(
|
||||
cfg.FilterParams.Manifest(),
|
||||
)
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
// Verify there are no duplicate parameter names
|
||||
seenNames := make(map[string]bool)
|
||||
for _, param := range paramManifest {
|
||||
if _, exists := seenNames[param.Name]; exists {
|
||||
return nil, fmt.Errorf("parameter name must be unique across filterParams, projectParams, and sortParams. Duplicate parameter: %s", param.Name)
|
||||
}
|
||||
seenNames[param.Name] = true
|
||||
}
|
||||
|
||||
filterMcpManifest := cfg.FilterParams.McpManifest()
|
||||
|
||||
// Concatenate parameters for MCP `required` field
|
||||
concatRequiredManifest := slices.Concat(
|
||||
filterMcpManifest.Required,
|
||||
)
|
||||
if concatRequiredManifest == nil {
|
||||
concatRequiredManifest = []string{}
|
||||
}
|
||||
|
||||
// Concatenate parameters for MCP `properties` field
|
||||
concatPropertiesManifest := make(map[string]tools.ParameterMcpManifest)
|
||||
for name, p := range filterMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
|
||||
// Create a new McpToolsSchema with all parameters
|
||||
paramMcpManifest := tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: concatPropertiesManifest,
|
||||
Required: concatRequiredManifest,
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Collection: cfg.Collection,
|
||||
FilterPayload: cfg.FilterPayload,
|
||||
FilterParams: cfg.FilterParams,
|
||||
UpdatePayload: cfg.UpdatePayload,
|
||||
UpdateParams: cfg.UpdateParams,
|
||||
Canonical: cfg.Canonical,
|
||||
Upsert: cfg.Upsert,
|
||||
AllParams: allParameters,
|
||||
database: s.Client.Database(cfg.Database),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Description string `yaml:"description"`
|
||||
Collection string `yaml:"collection"`
|
||||
FilterPayload string `yaml:"filterPayload" validate:"required"`
|
||||
FilterParams tools.Parameters
|
||||
UpdatePayload string `yaml:"updatePayload" validate:"required"`
|
||||
UpdateParams tools.Parameters
|
||||
AllParams tools.Parameters
|
||||
Canonical bool `yaml:"canonical" validation:"required"`
|
||||
Upsert bool `yaml:"upsert"`
|
||||
|
||||
database *mongo.Database
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
filterString, err := common.ParsePayloadTemplate(t.FilterParams, t.FilterPayload, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error populating filter: %s", err)
|
||||
}
|
||||
|
||||
var filter = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
|
||||
}
|
||||
|
||||
updateString, err := common.GetUpdate(t.UpdateParams, t.UpdatePayload, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get update: %w", err)
|
||||
}
|
||||
|
||||
var update = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(updateString), t.Canonical, &update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
|
||||
}
|
||||
|
||||
res, err := t.database.Collection(t.Collection).UpdateOne(ctx, filter, update, options.Update().SetUpsert(t.Upsert))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error updating collection: %w", err)
|
||||
}
|
||||
|
||||
return res.ModifiedCount, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
159
internal/tools/mongodb/mongodbupdateone/mongodbupdateone_test.go
Normal file
159
internal/tools/mongodb/mongodbupdateone/mongodbupdateone_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodbupdateone_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdateone"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-update-one
|
||||
source: my-instance
|
||||
description: some description
|
||||
database: test_db
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name: {{json .name}} }
|
||||
filterParams:
|
||||
- name: name
|
||||
type: string
|
||||
description: small description
|
||||
updatePayload: |
|
||||
{ $set : { item: {{json .item}} } }
|
||||
updateParams:
|
||||
- name: item
|
||||
type: string
|
||||
description: small description
|
||||
canonical: true
|
||||
upsert: true
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": mongodbupdateone.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "mongodb-update-one",
|
||||
Source: "my-instance",
|
||||
AuthRequired: []string{},
|
||||
Database: "test_db",
|
||||
Collection: "test_coll",
|
||||
Canonical: true,
|
||||
FilterPayload: "{ name: {{json .name}} }\n",
|
||||
FilterParams: tools.Parameters{
|
||||
&tools.StringParameter{
|
||||
CommonParameter: tools.CommonParameter{
|
||||
Name: "name",
|
||||
Type: "string",
|
||||
Desc: "small description",
|
||||
},
|
||||
},
|
||||
},
|
||||
UpdatePayload: "{ $set : { item: {{json .item}} } }\n",
|
||||
UpdateParams: tools.Parameters{
|
||||
&tools.StringParameter{
|
||||
CommonParameter: tools.CommonParameter{
|
||||
Name: "item",
|
||||
Type: "string",
|
||||
Desc: "small description",
|
||||
},
|
||||
},
|
||||
},
|
||||
Upsert: true,
|
||||
Description: "some description",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlMongoQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "Invalid method",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mongodb-update-one
|
||||
source: my-instance
|
||||
description: some description
|
||||
collection: test_coll
|
||||
filterPayload: |
|
||||
{ name : {{json .name}} }`,
|
||||
err: `unable to parse tool "example_tool" as kind "mongodb-update-one"`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
760
tests/mongodb/mongodb_integration_test.go
Normal file
760
tests/mongodb/mongodb_integration_test.go
Normal file
@@ -0,0 +1,760 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mongodb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
var (
|
||||
MongoDbSourceKind = "mongodb"
|
||||
MongoDbToolKind = "mongodb-find"
|
||||
MongoDbUri = os.Getenv("MONGODB_URI")
|
||||
MongoDbDatabase = os.Getenv("MONGODB_DATABASE")
|
||||
ServiceAccountEmail = os.Getenv("SERVICE_ACCOUNT_EMAIL")
|
||||
)
|
||||
|
||||
func getMongoDBVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case MongoDbUri:
|
||||
t.Fatal("'MongoDbUri' not set")
|
||||
case MongoDbDatabase:
|
||||
t.Fatal("'MongoDbDatabase' not set")
|
||||
}
|
||||
return map[string]any{
|
||||
"kind": MongoDbSourceKind,
|
||||
"uri": MongoDbUri,
|
||||
}
|
||||
}
|
||||
|
||||
func initMongoDbDatabase(ctx context.Context, uri, database string) (*mongo.Database, error) {
|
||||
// Create a new mongodb Database
|
||||
client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to connect to mongodb: %s", err)
|
||||
}
|
||||
err = client.Ping(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to connect to mongodb: %s", err)
|
||||
}
|
||||
return client.Database(database), nil
|
||||
}
|
||||
|
||||
func TestMongoDBToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getMongoDBVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
database, err := initMongoDbDatabase(ctx, MongoDbUri, MongoDbDatabase)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create MongoDB connection: %s", err)
|
||||
}
|
||||
|
||||
// set up data for param tool
|
||||
//setupMongoDB(t, ctx, database)
|
||||
teardownDB := setupMongoDB(t, ctx, database)
|
||||
defer teardownDB(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := getMongoDBToolsConfig(sourceConfig, MongoDbToolKind)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, invokeParamWant, invokeIdNullWant, nullString, mcpInvokeParamWant := getMongoDBWants()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullString, true, true)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
|
||||
delete1Want := "1"
|
||||
deleteManyWant := "2"
|
||||
RunToolDeleteInvokeTest(t, delete1Want, deleteManyWant)
|
||||
insert1Want := `["68666e1035bb36bf1b4d47fb"]`
|
||||
insertManyWant := `["68667a6436ec7d0363668db7","68667a6436ec7d0363668db8","68667a6436ec7d0363668db9"]`
|
||||
RunToolInsertInvokeTest(t, insert1Want, insertManyWant)
|
||||
update1Want := "1"
|
||||
updateManyWant := "[2,0,2]"
|
||||
RunToolUpdateInvokeTest(t, update1Want, updateManyWant)
|
||||
aggregate1Want := `[{"id":2}]`
|
||||
aggregateManyWant := `[{"id":500},{"id":501}]`
|
||||
RunToolAggregateInvokeTest(t, aggregate1Want, aggregateManyWant)
|
||||
|
||||
}
|
||||
|
||||
func RunToolDeleteInvokeTest(t *testing.T, delete1Want, deleteManyWant string) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke my-delete-one-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-delete-one-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{ "id" : 100 }`)),
|
||||
want: delete1Want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-delete-many-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-delete-many-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{ "id" : 101 }`)),
|
||||
want: deleteManyWant,
|
||||
isErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range invokeTcs {
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RunToolInsertInvokeTest(t *testing.T, insert1Want, insertManyWant string) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke my-insert-one-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-insert-one-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{ "data" : "{ \"_id\": { \"$oid\": \"68666e1035bb36bf1b4d47fb\" }, \"id\" : 200 }" }"`)),
|
||||
want: insert1Want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-insert-many-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-insert-many-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{ "data" : "[{ \"_id\": { \"$oid\": \"68667a6436ec7d0363668db7\"} , \"id\" : 201 }, { \"_id\" : { \"$oid\": \"68667a6436ec7d0363668db8\"}, \"id\" : 202 }, { \"_id\": { \"$oid\": \"68667a6436ec7d0363668db9\"}, \"id\": 203 }]" }`)),
|
||||
want: insertManyWant,
|
||||
isErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range invokeTcs {
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RunToolUpdateInvokeTest(t *testing.T, update1Want, updateManyWant string) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke my-update-one-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-update-one-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{ "id": 300, "name": "Bob" }`)),
|
||||
want: update1Want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-update-many-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-update-many-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{ "id": 400, "name" : "Alice" }`)),
|
||||
want: updateManyWant,
|
||||
isErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range invokeTcs {
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
func RunToolAggregateInvokeTest(t *testing.T, aggregate1Want string, aggregateManyWant string) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke my-aggregate-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-aggregate-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{ "name": "Jane" }`)),
|
||||
want: aggregate1Want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-aggregate-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-aggregate-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{ "name" : "ToBeAggregated" }`)),
|
||||
want: aggregateManyWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-read-only-aggregate-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-read-only-aggregate-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{ "name" : "ToBeAggregated" }`)),
|
||||
want: "",
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "invoke my-read-write-aggregate-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-read-write-aggregate-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{ "name" : "ToBeAggregated" }`)),
|
||||
want: "[]",
|
||||
isErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range invokeTcs {
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setupMongoDB(t *testing.T, ctx context.Context, database *mongo.Database) func(*testing.T) {
|
||||
collectionName := "test_collection"
|
||||
|
||||
documents := []map[string]any{
|
||||
{"_id": 1, "id": 1, "name": "Alice", "email": ServiceAccountEmail},
|
||||
{"_id": 2, "id": 2, "name": "Jane"},
|
||||
{"_id": 3, "id": 3, "name": "Sid"},
|
||||
{"_id": 4, "id": 4, "name": nil},
|
||||
{"_id": 5, "id": 3, "name": "Alice", "email": "alice@gmail.com"},
|
||||
{"_id": 6, "id": 100, "name": "ToBeDeleted", "email": "bob@gmail.com"},
|
||||
{"_id": 7, "id": 101, "name": "ToBeDeleted", "email": "bob1@gmail.com"},
|
||||
{"_id": 8, "id": 101, "name": "ToBeDeleted", "email": "bob2@gmail.com"},
|
||||
{"_id": 9, "id": 300, "name": "ToBeUpdatedToBob", "email": "bob@gmail.com"},
|
||||
{"_id": 10, "id": 400, "name": "ToBeUpdatedToAlice", "email": "alice@gmail.com"},
|
||||
{"_id": 11, "id": 400, "name": "ToBeUpdatedToAlice", "email": "alice@gmail.com"},
|
||||
{"_id": 12, "id": 500, "name": "ToBeAggregated", "email": "agatha@gmail.com"},
|
||||
{"_id": 13, "id": 501, "name": "ToBeAggregated", "email": "agatha@gmail.com"},
|
||||
}
|
||||
for _, doc := range documents {
|
||||
_, err := database.Collection(collectionName).InsertOne(ctx, doc)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to insert test data: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return func(t *testing.T) {
|
||||
// tear down test
|
||||
err := database.Collection(collectionName).Drop(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Teardown failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func getMongoDBToolsConfig(sourceConfig map[string]any, toolKind string) map[string]any {
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-instance": sourceConfig,
|
||||
},
|
||||
"authServices": map[string]any{
|
||||
"my-google-auth": map[string]any{
|
||||
"kind": "google",
|
||||
"clientId": tests.ClientId,
|
||||
},
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"kind": "mongodb-find-one",
|
||||
"source": "my-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"collection": "test_collection",
|
||||
"filterPayload": `{ "_id" : 3 }`,
|
||||
"filterParams": []any{},
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"filterPayload": `{ "id" : {{ .id }}, "name" : {{json .name }} }`,
|
||||
"filterParams": []map[string]any{
|
||||
{
|
||||
"name": "id",
|
||||
"type": "integer",
|
||||
"description": "user id",
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "user name",
|
||||
},
|
||||
},
|
||||
"projectPayload": `{ "_id": 1, "id": 1, "name" : 1 }`,
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-tool-by-id": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"filterPayload": `{ "id" : {{ .id }} }`,
|
||||
"filterParams": []map[string]any{
|
||||
{
|
||||
"name": "id",
|
||||
"type": "integer",
|
||||
"description": "user id",
|
||||
},
|
||||
},
|
||||
"projectPayload": `{ "_id": 1, "id": 1, "name" : 1 }`,
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-tool-by-name": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"filterPayload": `{ "name" : {{ .name }} }`,
|
||||
"filterParams": []map[string]any{
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "user name",
|
||||
"required": false,
|
||||
},
|
||||
},
|
||||
"projectPayload": `{ "_id": 1, "id": 1, "name" : 1 }`,
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-array-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with array.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"filterPayload": `{ "name": { "$in": {{json .nameArray}} }, "_id": 5 })`,
|
||||
"filterParams": []map[string]any{
|
||||
{
|
||||
"name": "nameArray",
|
||||
"type": "array",
|
||||
"description": "user names",
|
||||
"items": map[string]any{
|
||||
"name": "username",
|
||||
"type": "string",
|
||||
"description": "string item"},
|
||||
},
|
||||
},
|
||||
"projectPayload": `{ "_id": 1, "id": 1, "name" : 1 }`,
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-auth-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test authenticated parameters.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"filterPayload": `{ "email" : {{json .email }} }`,
|
||||
"filterParams": []map[string]any{
|
||||
{
|
||||
"name": "email",
|
||||
"type": "string",
|
||||
"description": "user email",
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth",
|
||||
"field": "email",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"projectPayload": `{ "_id": 0, "name" : 1 }`,
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-auth-required-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test auth required invocation.",
|
||||
"authRequired": []string{
|
||||
"my-google-auth",
|
||||
},
|
||||
"collection": "test_collection",
|
||||
"filterPayload": `{ "_id": 3, "id": 3 }`,
|
||||
"filterParams": []any{},
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-fail-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test statement with incorrect syntax.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"filterPayload": `{ "id" ; 1 }"}`,
|
||||
"filterParams": []any{},
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-delete-one-tool": map[string]any{
|
||||
"kind": "mongodb-delete-one",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test deleting an entry.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"filterPayload": `{ "id" : 100 }"}`,
|
||||
"filterParams": []any{},
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-delete-many-tool": map[string]any{
|
||||
"kind": "mongodb-delete-many",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test deleting multiple entries.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"filterPayload": `{ "id" : 101 }"}`,
|
||||
"filterParams": []any{},
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-insert-one-tool": map[string]any{
|
||||
"kind": "mongodb-insert-one",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test inserting an entry.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"canonical": true,
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-insert-many-tool": map[string]any{
|
||||
"kind": "mongodb-insert-many",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test inserting multiple entries.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"canonical": true,
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-update-one-tool": map[string]any{
|
||||
"kind": "mongodb-update-one",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test updating an entry.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"canonical": true,
|
||||
"filterPayload": `{ "id" : 300 }`,
|
||||
"filterParams": []any{},
|
||||
"updatePayload": `{ "$set" : { "name": {{json .name}} } }`,
|
||||
"updateParams": []map[string]any{
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "user name",
|
||||
},
|
||||
},
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-update-many-tool": map[string]any{
|
||||
"kind": "mongodb-update-many",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test updating multiple entries.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"canonical": true,
|
||||
"filterPayload": `{ "id" : {{ .id }} }`,
|
||||
"filterParams": []map[string]any{
|
||||
{
|
||||
"name": "id",
|
||||
"type": "integer",
|
||||
"description": "id",
|
||||
},
|
||||
},
|
||||
"updatePayload": `{ "$set" : { "name": {{json .name}} } }`,
|
||||
"updateParams": []map[string]any{
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "user name",
|
||||
},
|
||||
},
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-aggregate-tool": map[string]any{
|
||||
"kind": "mongodb-aggregate",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test an aggregation.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"canonical": true,
|
||||
"pipelinePayload": `[{ "$match" : { "name": {{json .name}} } }, { "$project" : { "id" : 1, "_id" : 0 }}]`,
|
||||
"pipelineParams": []map[string]any{
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "user name",
|
||||
},
|
||||
},
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-read-only-aggregate-tool": map[string]any{
|
||||
"kind": "mongodb-aggregate",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test an aggregation.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"canonical": true,
|
||||
"readOnly": true,
|
||||
"pipelinePayload": `[{ "$match" : { "name": {{json .name}} } }, { "$out" : "target_collection" }]`,
|
||||
"pipelineParams": []map[string]any{
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "user name",
|
||||
},
|
||||
},
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
"my-read-write-aggregate-tool": map[string]any{
|
||||
"kind": "mongodb-aggregate",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test an aggregation.",
|
||||
"authRequired": []string{},
|
||||
"collection": "test_collection",
|
||||
"canonical": true,
|
||||
"readOnly": false,
|
||||
"pipelinePayload": `[{ "$match" : { "name": {{json .name}} } }, { "$out" : "target_collection" }]`,
|
||||
"pipelineParams": []map[string]any{
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "user name",
|
||||
},
|
||||
},
|
||||
"database": MongoDbDatabase,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return toolsFile
|
||||
|
||||
}
|
||||
|
||||
func getMongoDBWants() (string, string, string, string, string, string) {
|
||||
select1Want := `[{"_id":3,"id":3,"name":"Sid"}]`
|
||||
failInvocationWant := `invalid JSON input: missing colon after key `
|
||||
invokeParamWant := `[{"_id":5,"id":3,"name":"Alice"}]`
|
||||
invokeParamWantNull := `[{"_id":4,"id":4,"name":null}]`
|
||||
mcpInvokeParamWant := `{"jsonrpc":"2.0","id":"my-tool","result":{"content":[{"type":"text","text":"{\"_id\":5,\"id\":3,\"name\":\"Alice\"}"}]}}`
|
||||
nullString := "null"
|
||||
return select1Want, failInvocationWant, invokeParamWant, invokeParamWantNull, nullString, mcpInvokeParamWant
|
||||
}
|
||||
Reference in New Issue
Block a user