mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-21 13:27:55 -05:00
Compare commits
11 Commits
multi_preb
...
fix-adk-js
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
48794b1bf6 | ||
|
|
42bff296e3 | ||
|
|
0b40d0660e | ||
|
|
40b95f9552 | ||
|
|
5f0e15b444 | ||
|
|
635ea97ea9 | ||
|
|
bf4921158d | ||
|
|
e4f60e5633 | ||
|
|
d7af21bdde | ||
|
|
adc9589766 | ||
|
|
c25a2330fe |
@@ -87,7 +87,7 @@ steps:
|
||||
- "CLOUD_SQL_POSTGRES_REGION=$_REGION"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv:
|
||||
["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID"]
|
||||
["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
@@ -134,7 +134,7 @@ steps:
|
||||
- "ALLOYDB_POSTGRES_DATABASE=$_DATABASE_NAME"
|
||||
- "ALLOYDB_POSTGRES_REGION=$_REGION"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID"]
|
||||
secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
@@ -293,7 +293,7 @@ steps:
|
||||
.ci/test_with_coverage.sh \
|
||||
"Cloud Healthcare API" \
|
||||
cloudhealthcare \
|
||||
cloudhealthcare || echo "Integration tests failed."
|
||||
cloudhealthcare
|
||||
|
||||
- id: "postgres"
|
||||
name: golang:1
|
||||
@@ -305,7 +305,7 @@ steps:
|
||||
- "POSTGRES_HOST=$_POSTGRES_HOST"
|
||||
- "POSTGRES_PORT=$_POSTGRES_PORT"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID"]
|
||||
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
@@ -964,6 +964,13 @@ steps:
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
# Common secrets
|
||||
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
|
||||
env: CLIENT_ID
|
||||
- versionName: projects/$PROJECT_ID/secrets/api_key/versions/latest
|
||||
env: API_KEY
|
||||
|
||||
# Resource-specific secrets
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
|
||||
env: CLOUD_SQL_POSTGRES_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_pass/versions/latest
|
||||
@@ -980,8 +987,6 @@ availableSecrets:
|
||||
env: POSTGRES_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest
|
||||
env: POSTGRES_PASS
|
||||
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
|
||||
env: CLIENT_ID
|
||||
- versionName: projects/$PROJECT_ID/secrets/neo4j_user/versions/latest
|
||||
env: NEO4J_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest
|
||||
|
||||
@@ -77,11 +77,24 @@ run_orch_test() {
|
||||
setup_orch_table
|
||||
|
||||
cd "$orch_dir"
|
||||
|
||||
echo "--- Debugging npm config for $orch_name ---"
|
||||
npm config list
|
||||
|
||||
echo "--- Active Registry for $orch_name ---"
|
||||
npm config get registry
|
||||
|
||||
echo "--- Inspecting .npmrc files ---"
|
||||
[ -f ".npmrc" ] && echo "Local .npmrc content:" && cat .npmrc
|
||||
[ -f "$HOME/.npmrc" ] && echo "Global .npmrc content:" && cat "$HOME/.npmrc"
|
||||
|
||||
export GPKG_DEBUG=1
|
||||
|
||||
echo "Installing dependencies for $orch_name..."
|
||||
if [ -f "package-lock.json" ]; then
|
||||
npm ci
|
||||
npm ci --loglevel verbose
|
||||
else
|
||||
npm install
|
||||
npm install --loglevel verbose
|
||||
fi
|
||||
|
||||
cd ..
|
||||
|
||||
67
cmd/root.go
67
cmd/root.go
@@ -315,15 +315,15 @@ func Execute() {
|
||||
type Command struct {
|
||||
*cobra.Command
|
||||
|
||||
cfg server.ServerConfig
|
||||
logger log.Logger
|
||||
tools_file string
|
||||
tools_files []string
|
||||
tools_folder string
|
||||
prebuiltConfigs []string
|
||||
inStream io.Reader
|
||||
outStream io.Writer
|
||||
errStream io.Writer
|
||||
cfg server.ServerConfig
|
||||
logger log.Logger
|
||||
tools_file string
|
||||
tools_files []string
|
||||
tools_folder string
|
||||
prebuiltConfig string
|
||||
inStream io.Reader
|
||||
outStream io.Writer
|
||||
errStream io.Writer
|
||||
}
|
||||
|
||||
// NewCommand returns a Command object representing an invocation of the CLI.
|
||||
@@ -376,16 +376,17 @@ func NewCommand(opts ...Option) *Command {
|
||||
flags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.")
|
||||
// Fetch prebuilt tools sources to customize the help description
|
||||
prebuiltHelp := fmt.Sprintf(
|
||||
"Use a prebuilt tool configuration by source type. Allowed: '%s'. Can be specified multiple times.",
|
||||
"Use a prebuilt tool configuration by source type. Allowed: '%s'.",
|
||||
strings.Join(prebuiltconfigs.GetPrebuiltSources(), "', '"),
|
||||
)
|
||||
flags.StringSliceVar(&cmd.prebuiltConfigs, "prebuilt", []string{}, prebuiltHelp)
|
||||
flags.StringVar(&cmd.prebuiltConfig, "prebuilt", "", prebuiltHelp)
|
||||
flags.BoolVar(&cmd.cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.")
|
||||
flags.BoolVar(&cmd.cfg.DisableReload, "disable-reload", false, "Disables dynamic reloading of tools file.")
|
||||
flags.BoolVar(&cmd.cfg.UI, "ui", false, "Launches the Toolbox UI web server.")
|
||||
// TODO: Insecure by default. Might consider updating this for v1.0.0
|
||||
flags.StringSliceVar(&cmd.cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.")
|
||||
flags.StringSliceVar(&cmd.cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.")
|
||||
flags.StringSliceVar(&cmd.cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.")
|
||||
|
||||
// wrap RunE command so that we have access to original Command object
|
||||
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }
|
||||
@@ -866,32 +867,24 @@ func run(cmd *Command) error {
|
||||
var allToolsFiles []ToolsFile
|
||||
|
||||
// Load Prebuilt Configuration
|
||||
|
||||
if len(cmd.prebuiltConfigs) > 0 {
|
||||
slices.Sort(cmd.prebuiltConfigs)
|
||||
sourcesList := strings.Join(cmd.prebuiltConfigs, ", ")
|
||||
logMsg := fmt.Sprintf("Using prebuilt tool configurations for: %s", sourcesList)
|
||||
cmd.logger.InfoContext(ctx, logMsg)
|
||||
|
||||
for _, configName := range cmd.prebuiltConfigs {
|
||||
buf, err := prebuiltconfigs.Get(configName)
|
||||
if err != nil {
|
||||
cmd.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// Update version string
|
||||
cmd.cfg.Version += "+prebuilt." + configName
|
||||
|
||||
// Parse into ToolsFile struct
|
||||
parsed, err := parseToolsFile(ctx, buf)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to parse prebuilt tool configuration for '%s': %w", configName, err)
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
allToolsFiles = append(allToolsFiles, parsed)
|
||||
if cmd.prebuiltConfig != "" {
|
||||
buf, err := prebuiltconfigs.Get(cmd.prebuiltConfig)
|
||||
if err != nil {
|
||||
cmd.logger.ErrorContext(ctx, err.Error())
|
||||
return err
|
||||
}
|
||||
logMsg := fmt.Sprint("Using prebuilt tool configuration for ", cmd.prebuiltConfig)
|
||||
cmd.logger.InfoContext(ctx, logMsg)
|
||||
// Append prebuilt.source to Version string for the User Agent
|
||||
cmd.cfg.Version += "+prebuilt." + cmd.prebuiltConfig
|
||||
|
||||
parsed, err := parseToolsFile(ctx, buf)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to parse prebuilt tool configuration: %w", err)
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
allToolsFiles = append(allToolsFiles, parsed)
|
||||
}
|
||||
|
||||
// Determine if Custom Files should be loaded
|
||||
@@ -899,7 +892,7 @@ func run(cmd *Command) error {
|
||||
isCustomConfigured := cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != ""
|
||||
|
||||
// Determine if default 'tools.yaml' should be used (No prebuilt AND No custom flags)
|
||||
useDefaultToolsFile := len(cmd.prebuiltConfigs) == 0 && !isCustomConfigured
|
||||
useDefaultToolsFile := cmd.prebuiltConfig == "" && !isCustomConfigured
|
||||
|
||||
if useDefaultToolsFile {
|
||||
cmd.tools_file = "tools.yaml"
|
||||
|
||||
@@ -70,6 +70,9 @@ func withDefaults(c server.ServerConfig) server.ServerConfig {
|
||||
if c.AllowedHosts == nil {
|
||||
c.AllowedHosts = []string{"*"}
|
||||
}
|
||||
if c.UserAgentMetadata == nil {
|
||||
c.UserAgentMetadata = []string{}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -230,6 +233,13 @@ func TestServerConfigFlags(t *testing.T) {
|
||||
AllowedHosts: []string{"http://foo.com", "http://bar.com"},
|
||||
}),
|
||||
},
|
||||
{
|
||||
desc: "user agent metadata",
|
||||
args: []string{"--user-agent-metadata", "foo,bar"},
|
||||
want: withDefaults(server.ServerConfig{
|
||||
UserAgentMetadata: []string{"foo", "bar"},
|
||||
}),
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
@@ -420,27 +430,17 @@ func TestPrebuiltFlag(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
args []string
|
||||
want []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
desc: "default value",
|
||||
args: []string{},
|
||||
want: []string{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
desc: "single prebuilt flag",
|
||||
args: []string{"--prebuilt", "alloydb"},
|
||||
want: []string{"alloydb"},
|
||||
},
|
||||
{
|
||||
desc: "multiple prebuilt flags",
|
||||
args: []string{"--prebuilt", "alloydb", "--prebuilt", "bigquery"},
|
||||
want: []string{"alloydb", "bigquery"},
|
||||
},
|
||||
{
|
||||
desc: "comma separated prebuilt flags",
|
||||
args: []string{"--prebuilt", "alloydb,bigquery"},
|
||||
want: []string{"alloydb", "bigquery"},
|
||||
desc: "custom pre built flag",
|
||||
args: []string{"--tools-file", "alloydb"},
|
||||
want: "alloydb",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
@@ -449,8 +449,8 @@ func TestPrebuiltFlag(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error invoking command: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(c.prebuiltConfigs, tc.want); diff != "" {
|
||||
t.Fatalf("got %v, want %v, diff %s", c.prebuiltConfigs, tc.want, diff)
|
||||
if c.tools_file != tc.want {
|
||||
t.Fatalf("got %v, want %v", c.cfg, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -2073,12 +2073,6 @@ authSources:
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "sqlite called twice error",
|
||||
args: []string{"--prebuilt", "sqlite", "--prebuilt", "sqlite"},
|
||||
wantErr: true,
|
||||
errString: "resource conflicts detected",
|
||||
},
|
||||
{
|
||||
desc: "tool conflict error",
|
||||
args: []string{"--prebuilt", "sqlite", "--tools-file", toolConflictFile},
|
||||
|
||||
2222
docs/en/getting-started/quickstart/js/adk/package-lock.json
generated
2222
docs/en/getting-started/quickstart/js/adk/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -11,7 +11,7 @@
|
||||
"author": "",
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"@google/adk": "^0.1.3",
|
||||
"@google/adk": "^0.2.4",
|
||||
"@toolbox-sdk/adk": "^0.1.5"
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,7 +11,7 @@
|
||||
"author": "",
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"@llamaindex/google": "^0.3.20",
|
||||
"@llamaindex/google": "^0.4.0",
|
||||
"@llamaindex/workflow": "^1.1.22",
|
||||
"@toolbox-sdk/core": "^0.1.2",
|
||||
"llamaindex": "^0.12.0"
|
||||
|
||||
@@ -16,7 +16,7 @@ description: >
|
||||
| | `--log-level` | Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'. | `info` |
|
||||
| | `--logging-format` | Specify logging format to use. Allowed: 'standard' or 'JSON'. | `standard` |
|
||||
| `-p` | `--port` | Port the server will listen on. | `5000` |
|
||||
| | `--prebuilt` | Use one or more prebuilt tool configuration by source type. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | |
|
||||
| | `--prebuilt` | Use a prebuilt tool configuration by source type. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | |
|
||||
| | `--stdio` | Listens via MCP STDIO instead of acting as a remote HTTP server. | |
|
||||
| | `--telemetry-gcp` | Enable exporting directly to Google Cloud Monitoring. | |
|
||||
| | `--telemetry-otlp` | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318') | |
|
||||
@@ -27,6 +27,7 @@ description: >
|
||||
| | `--ui` | Launches the Toolbox UI web server. | |
|
||||
| | `--allowed-origins` | Specifies a list of origins permitted to access this server for CORs access. | `*` |
|
||||
| | `--allowed-hosts` | Specifies a list of hosts permitted to access this server to prevent DNS rebinding attacks. | `*` |
|
||||
| | `--user-agent-extra` | Appends additional metadata to the User-Agent. | |
|
||||
| `-v` | `--version` | version for toolbox | |
|
||||
|
||||
## Examples
|
||||
@@ -50,11 +51,6 @@ description: >
|
||||
|
||||
# Server with prebuilt + custom tools configurations
|
||||
./toolbox --tools-file tools.yaml --prebuilt alloydb-postgres
|
||||
|
||||
# Server with multiple prebuilt tools configurations
|
||||
./toolbox --prebuilt alloydb-postgres,alloydb-postgres-admin
|
||||
# OR
|
||||
./toolbox --prebuilt alloydb-postgres --prebuilt alloydb-postgres-admin
|
||||
```
|
||||
|
||||
### Tool Configuration Sources
|
||||
@@ -75,7 +71,7 @@ The CLI supports multiple mutually exclusive ways to specify tool configurations
|
||||
|
||||
**Prebuilt Configurations:**
|
||||
|
||||
- `--prebuilt`: Use one or more predefined configurations for specific database types (e.g.,
|
||||
- `--prebuilt`: Use predefined configurations for specific database types (e.g.,
|
||||
'bigquery', 'postgres', 'spanner'). See [Prebuilt Tools
|
||||
Reference](prebuilt-tools.md) for allowed values.
|
||||
|
||||
|
||||
@@ -16,9 +16,6 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP.
|
||||
{{< notice tip >}}
|
||||
You can now use `--prebuilt` along `--tools-file`, `--tools-files`, or
|
||||
`--tools-folder` to combine prebuilt configs with custom tools.
|
||||
|
||||
You can also combine multiple prebuilt configs.
|
||||
|
||||
See [Usage Examples](../reference/cli.md#examples).
|
||||
{{< /notice >}}
|
||||
|
||||
|
||||
@@ -64,12 +64,14 @@ type ServerConfig struct {
|
||||
Stdio bool
|
||||
// DisableReload indicates if the user has disabled dynamic reloading for Toolbox.
|
||||
DisableReload bool
|
||||
// UI indicates if Toolbox UI endpoints (/ui) are available
|
||||
// UI indicates if Toolbox UI endpoints (/ui) are available.
|
||||
UI bool
|
||||
// Specifies a list of origins permitted to access this server.
|
||||
AllowedOrigins []string
|
||||
// Specifies a list of hosts permitted to access this server
|
||||
// Specifies a list of hosts permitted to access this server.
|
||||
AllowedHosts []string
|
||||
// UserAgentMetadata specifies additional metadata to append to the User-Agent string.
|
||||
UserAgentMetadata []string
|
||||
}
|
||||
|
||||
type logFormat string
|
||||
|
||||
@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
embeddingModels := resourceMgr.GetEmbeddingModelMap()
|
||||
params, err = tool.EmbedParams(ctx, params, embeddingModels)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error embedding parameters: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||
if err != nil {
|
||||
|
||||
@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
embeddingModels := resourceMgr.GetEmbeddingModelMap()
|
||||
params, err = tool.EmbedParams(ctx, params, embeddingModels)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error embedding parameters: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||
if err != nil {
|
||||
|
||||
@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
embeddingModels := resourceMgr.GetEmbeddingModelMap()
|
||||
params, err = tool.EmbedParams(ctx, params, embeddingModels)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error embedding parameters: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||
if err != nil {
|
||||
|
||||
@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
embeddingModels := resourceMgr.GetEmbeddingModelMap()
|
||||
params, err = tool.EmbedParams(ctx, params, embeddingModels)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error embedding parameters: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||
if err != nil {
|
||||
|
||||
@@ -64,7 +64,11 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
map[string]prompts.Promptset,
|
||||
error,
|
||||
) {
|
||||
ctx = util.WithUserAgent(ctx, cfg.Version)
|
||||
metadataStr := cfg.Version
|
||||
if len(cfg.UserAgentMetadata) > 0 {
|
||||
metadataStr += "+" + strings.Join(cfg.UserAgentMetadata, "+")
|
||||
}
|
||||
ctx = util.WithUserAgent(ctx, metadataStr)
|
||||
instrumentation, err := util.InstrumentationFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
||||
@@ -103,10 +103,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
|
||||
}
|
||||
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
return source.FHIRFetchPage(ctx, url, tokenStr)
|
||||
}
|
||||
|
||||
@@ -131,9 +131,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
|
||||
}
|
||||
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var opts []googleapi.CallOption
|
||||
|
||||
@@ -161,9 +161,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var summary bool
|
||||
|
||||
@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
return source.GetDataset(tokenStr)
|
||||
}
|
||||
|
||||
@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
return source.GetDICOMStore(storeID, tokenStr)
|
||||
}
|
||||
|
||||
@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
return source.GetDICOMStoreMetrics(storeID, tokenStr)
|
||||
}
|
||||
|
||||
@@ -130,9 +130,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
return source.GetFHIRResource(storeID, resType, resID, tokenStr)
|
||||
}
|
||||
|
||||
@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
return source.GetFHIRStore(storeID, tokenStr)
|
||||
}
|
||||
|
||||
@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
return source.GetFHIRStoreMetrics(storeID, tokenStr)
|
||||
}
|
||||
|
||||
@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
return source.ListDICOMStores(tokenStr)
|
||||
}
|
||||
|
||||
@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
return source.ListFHIRStores(tokenStr)
|
||||
}
|
||||
|
||||
@@ -127,9 +127,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
study, ok := params.AsMap()[studyInstanceUIDKey].(string)
|
||||
if !ok {
|
||||
|
||||
@@ -140,9 +140,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
opts, err := common.ParseDICOMSearchParameters(params, []string{sopInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})
|
||||
|
||||
@@ -138,9 +138,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
opts, err := common.ParseDICOMSearchParameters(params, []string{seriesInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})
|
||||
|
||||
@@ -133,9 +133,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
opts, err := common.ParseDICOMSearchParameters(params, []string{studyInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey})
|
||||
if err != nil {
|
||||
|
||||
@@ -147,12 +147,20 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Set up table for semanti search
|
||||
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
|
||||
defer tearDownVectorTable(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
// Add semantic search tool config
|
||||
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
|
||||
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, AlloyDBPostgresToolKind, insertStmt, searchStmt)
|
||||
|
||||
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
|
||||
@@ -112,8 +112,7 @@ func TestHealthcareToolEndpoints(t *testing.T) {
|
||||
fhirStoreID := "fhir-store-" + uuid.New().String()
|
||||
dicomStoreID := "dicom-store-" + uuid.New().String()
|
||||
|
||||
patient1ID, patient2ID, teardown := setupHealthcareResources(t, healthcareService, healthcareDataset, fhirStoreID, dicomStoreID)
|
||||
defer teardown(t)
|
||||
patient1ID, patient2ID := setupHealthcareResources(t, healthcareService, healthcareDataset, fhirStoreID, dicomStoreID)
|
||||
|
||||
toolsFile := getToolsConfig(sourceConfig)
|
||||
toolsFile = addClientAuthSourceConfig(t, toolsFile)
|
||||
@@ -173,10 +172,8 @@ func TestHealthcareToolWithStoreRestriction(t *testing.T) {
|
||||
disallowedFHIRStoreID := "fhir-store-disallowed-" + uuid.New().String()
|
||||
disallowedDICOMStoreID := "dicom-store-disallowed-" + uuid.New().String()
|
||||
|
||||
_, _, teardownAllowedStores := setupHealthcareResources(t, healthcareService, healthcareDataset, allowedFHIRStoreID, allowedDICOMStoreID)
|
||||
defer teardownAllowedStores(t)
|
||||
_, _, teardownDisallowedStores := setupHealthcareResources(t, healthcareService, healthcareDataset, disallowedFHIRStoreID, disallowedDICOMStoreID)
|
||||
defer teardownDisallowedStores(t)
|
||||
setupHealthcareResources(t, healthcareService, healthcareDataset, allowedFHIRStoreID, allowedDICOMStoreID)
|
||||
setupHealthcareResources(t, healthcareService, healthcareDataset, disallowedFHIRStoreID, disallowedDICOMStoreID)
|
||||
|
||||
// Configure source with dataset restriction.
|
||||
sourceConfig["allowedFhirStores"] = []string{allowedFHIRStoreID}
|
||||
@@ -257,7 +254,7 @@ func newHealthcareService(ctx context.Context) (*healthcare.Service, error) {
|
||||
return healthcareService, nil
|
||||
}
|
||||
|
||||
func setupHealthcareResources(t *testing.T, service *healthcare.Service, datasetID, fhirStoreID, dicomStoreID string) (string, string, func(*testing.T)) {
|
||||
func setupHealthcareResources(t *testing.T, service *healthcare.Service, datasetID, fhirStoreID, dicomStoreID string) (string, string) {
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", healthcareProject, healthcareRegion, datasetID)
|
||||
var err error
|
||||
|
||||
@@ -266,12 +263,24 @@ func setupHealthcareResources(t *testing.T, service *healthcare.Service, dataset
|
||||
if fhirStore, err = service.Projects.Locations.Datasets.FhirStores.Create(datasetName, fhirStore).FhirStoreId(fhirStoreID).Do(); err != nil {
|
||||
t.Fatalf("failed to create fhir store: %v", err)
|
||||
}
|
||||
// Register cleanup
|
||||
t.Cleanup(func() {
|
||||
if _, err := service.Projects.Locations.Datasets.FhirStores.Delete(fhirStore.Name).Do(); err != nil {
|
||||
t.Logf("failed to delete fhir store: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Create DICOM store
|
||||
dicomStore := &healthcare.DicomStore{}
|
||||
if dicomStore, err = service.Projects.Locations.Datasets.DicomStores.Create(datasetName, dicomStore).DicomStoreId(dicomStoreID).Do(); err != nil {
|
||||
t.Fatalf("failed to create dicom store: %v", err)
|
||||
}
|
||||
// Register cleanup
|
||||
t.Cleanup(func() {
|
||||
if _, err := service.Projects.Locations.Datasets.DicomStores.Delete(dicomStore.Name).Do(); err != nil {
|
||||
t.Logf("failed to delete dicom store: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Create Patient 1
|
||||
patient1Body := bytes.NewBuffer([]byte(`{
|
||||
@@ -317,15 +326,7 @@ func setupHealthcareResources(t *testing.T, service *healthcare.Service, dataset
|
||||
createFHIRResource(t, service, fhirStore.Name, "Observation", observation2Body)
|
||||
}
|
||||
|
||||
teardown := func(t *testing.T) {
|
||||
if _, err := service.Projects.Locations.Datasets.FhirStores.Delete(fhirStore.Name).Do(); err != nil {
|
||||
t.Logf("failed to delete fhir store: %v", err)
|
||||
}
|
||||
if _, err := service.Projects.Locations.Datasets.DicomStores.Delete(dicomStore.Name).Do(); err != nil {
|
||||
t.Logf("failed to delete dicom store: %v", err)
|
||||
}
|
||||
}
|
||||
return patient1ID, patient2ID, teardown
|
||||
return patient1ID, patient2ID
|
||||
}
|
||||
|
||||
func getToolsConfig(sourceConfig map[string]any) map[string]any {
|
||||
|
||||
@@ -132,12 +132,20 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Set up table for semantic search
|
||||
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
|
||||
defer tearDownVectorTable(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
// Add semantic search tool config
|
||||
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
|
||||
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, CloudSQLPostgresToolKind, insertStmt, searchStmt)
|
||||
|
||||
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -186,6 +194,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
|
||||
tests.RunPostgresListRolesTest(t, ctx, pool)
|
||||
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
|
||||
tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox")
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
251
tests/embedding.go
Normal file
251
tests/embedding.go
Normal file
@@ -0,0 +1,251 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package tests contains end to end tests meant to verify the Toolbox Server
|
||||
// works as expected when executed as a binary.
|
||||
|
||||
package tests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var apiKey = os.Getenv("API_KEY")
|
||||
|
||||
// AddSemanticSearchConfig adds embedding models and semantic search tools to the config
|
||||
// with configurable tool kind and SQL statements.
|
||||
func AddSemanticSearchConfig(t *testing.T, config map[string]any, toolKind, insertStmt, searchStmt string) map[string]any {
|
||||
config["embeddingModels"] = map[string]any{
|
||||
"gemini_model": map[string]any{
|
||||
"kind": "gemini",
|
||||
"model": "gemini-embedding-001",
|
||||
"apiKey": apiKey,
|
||||
"dimension": 768,
|
||||
},
|
||||
}
|
||||
|
||||
tools, ok := config["tools"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("unable to get tools from config")
|
||||
}
|
||||
|
||||
tools["insert_docs"] = map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Stores content and its vector embedding into the documents table.",
|
||||
"statement": insertStmt,
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "content",
|
||||
"type": "string",
|
||||
"description": "The text content associated with the vector.",
|
||||
},
|
||||
map[string]any{
|
||||
"name": "text_to_embed",
|
||||
"type": "string",
|
||||
"description": "The text content used to generate the vector.",
|
||||
"embeddedBy": "gemini_model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tools["search_docs"] = map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Finds the most semantically similar document to the query vector.",
|
||||
"statement": searchStmt,
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "query",
|
||||
"type": "string",
|
||||
"description": "The text content to search for.",
|
||||
"embeddedBy": "gemini_model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config["tools"] = tools
|
||||
return config
|
||||
}
|
||||
|
||||
// RunSemanticSearchToolInvokeTest runs the insert_docs and search_docs tools
|
||||
// via both HTTP and MCP endpoints and verifies the output.
|
||||
func RunSemanticSearchToolInvokeTest(t *testing.T, insertWant, mcpInsertWant, searchWant string) {
|
||||
// Initialize MCP session once for the MCP test cases
|
||||
sessionId := RunInitialize(t, "2024-11-05")
|
||||
|
||||
tcs := []struct {
|
||||
name string
|
||||
api string
|
||||
isMcp bool
|
||||
requestBody interface{}
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "HTTP invoke insert_docs",
|
||||
api: "http://127.0.0.1:5000/api/tool/insert_docs/invoke",
|
||||
isMcp: false,
|
||||
requestBody: `{"content": "The quick brown fox jumps over the lazy dog", "text_to_embed": "The quick brown fox jumps over the lazy dog"}`,
|
||||
want: insertWant,
|
||||
},
|
||||
{
|
||||
name: "HTTP invoke search_docs",
|
||||
api: "http://127.0.0.1:5000/api/tool/search_docs/invoke",
|
||||
isMcp: false,
|
||||
requestBody: `{"query": "fast fox jumping"}`,
|
||||
want: searchWant,
|
||||
},
|
||||
{
|
||||
name: "MCP invoke insert_docs",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
isMcp: true,
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "mcp-insert-docs",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": "insert_docs",
|
||||
"arguments": map[string]any{
|
||||
"content": "The quick brown fox jumps over the lazy dog",
|
||||
"text_to_embed": "The quick brown fox jumps over the lazy dog",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: mcpInsertWant,
|
||||
},
|
||||
{
|
||||
name: "MCP invoke search_docs",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
isMcp: true,
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "mcp-search-docs",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": "search_docs",
|
||||
"arguments": map[string]any{
|
||||
"query": "fast fox jumping",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: searchWant,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var bodyReader io.Reader
|
||||
headers := map[string]string{}
|
||||
|
||||
// Prepare Request Body and Headers
|
||||
if tc.isMcp {
|
||||
reqBytes, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal mcp request: %v", err)
|
||||
}
|
||||
bodyReader = bytes.NewBuffer(reqBytes)
|
||||
if sessionId != "" {
|
||||
headers["Mcp-Session-Id"] = sessionId
|
||||
}
|
||||
} else {
|
||||
bodyReader = bytes.NewBufferString(tc.requestBody.(string))
|
||||
}
|
||||
|
||||
// Send Request
|
||||
resp, respBody := RunRequest(t, http.MethodPost, tc.api, bodyReader, headers)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Normalize Response to get the actual tool result string
|
||||
var got string
|
||||
if tc.isMcp {
|
||||
var mcpResp struct {
|
||||
Result struct {
|
||||
Content []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
} `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &mcpResp); err != nil {
|
||||
t.Fatalf("error parsing mcp response: %s", err)
|
||||
}
|
||||
if len(mcpResp.Result.Content) > 0 {
|
||||
got = mcpResp.Result.Content[0].Text
|
||||
}
|
||||
} else {
|
||||
var httpResp map[string]interface{}
|
||||
if err := json.Unmarshal(respBody, &httpResp); err != nil {
|
||||
t.Fatalf("error parsing http response: %s", err)
|
||||
}
|
||||
if res, ok := httpResp["result"].(string); ok {
|
||||
got = res
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.Contains(got, tc.want) {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// SetupPostgresVectorTable sets up the vector extension and a vector table
|
||||
func SetupPostgresVectorTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool) (string, func(*testing.T)) {
|
||||
t.Helper()
|
||||
if _, err := pool.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector"); err != nil {
|
||||
t.Fatalf("failed to create vector extension: %v", err)
|
||||
}
|
||||
|
||||
tableName := "vector_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
createTableStmt := fmt.Sprintf(`CREATE TABLE %s (
|
||||
id SERIAL PRIMARY KEY,
|
||||
content TEXT,
|
||||
embedding vector(768)
|
||||
)`, tableName)
|
||||
|
||||
if _, err := pool.Exec(ctx, createTableStmt); err != nil {
|
||||
t.Fatalf("failed to create table %s: %v", tableName, err)
|
||||
}
|
||||
|
||||
return tableName, func(t *testing.T) {
|
||||
if _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)); err != nil {
|
||||
t.Errorf("failed to drop table %s: %v", tableName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetPostgresVectorSearchStmts(vectorTableName string) (string, string) {
|
||||
insertStmt := fmt.Sprintf("INSERT INTO %s (content, embedding) VALUES ($1, $2)", vectorTableName)
|
||||
searchStmt := fmt.Sprintf("SELECT id, content, embedding <-> $1 AS distance FROM %s ORDER BY distance LIMIT 1", vectorTableName)
|
||||
return insertStmt, searchStmt
|
||||
}
|
||||
@@ -111,6 +111,10 @@ func TestPostgres(t *testing.T) {
|
||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Set up table for semantic search
|
||||
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
|
||||
defer tearDownVectorTable(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
|
||||
@@ -118,6 +122,10 @@ func TestPostgres(t *testing.T) {
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
|
||||
|
||||
// Add semantic search tool config
|
||||
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
|
||||
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, PostgresToolKind, insertStmt, searchStmt)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
@@ -165,4 +173,5 @@ func TestPostgres(t *testing.T) {
|
||||
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
|
||||
tests.RunPostgresListRolesTest(t, ctx, pool)
|
||||
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
|
||||
tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox")
|
||||
}
|
||||
|
||||
@@ -1240,7 +1240,10 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user
|
||||
var filteredGot []any
|
||||
for _, item := range got {
|
||||
if tableMap, ok := item.(map[string]interface{}); ok {
|
||||
if schema, ok := tableMap["schema_name"]; ok && schema == "public" {
|
||||
name, _ := tableMap["object_name"].(string)
|
||||
|
||||
// Only keep the table if it matches expected test tables
|
||||
if name == tableNameParam || name == tableNameAuth {
|
||||
filteredGot = append(filteredGot, item)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user