Compare commits

..

11 Commits

Author SHA1 Message Date
AnmolShukla2002
48794b1bf6 fresh installation check 2026-01-21 15:28:01 +05:30
AnmolShukla2002
42bff296e3 updating adk deps to 2.x v 2026-01-21 15:03:58 +05:30
AnmolShukla2002
0b40d0660e fixing adk js issue 2026-01-21 14:51:50 +05:30
AnmolShukla2002
40b95f9552 check for internal reigstry redirection 2026-01-21 14:43:20 +05:30
AnmolShukla2002
5f0e15b444 debugging password error 2026-01-21 14:23:39 +05:30
Anmol Shukla
635ea97ea9 Merge branch 'main' into fix-adk-js 2026-01-21 12:11:17 +05:30
AnmolShukla2002
bf4921158d deps: fix adk and llamaindex sample deps issue 2026-01-21 12:08:51 +05:30
Wenxin Du
e4f60e5633 fix(embeddingModel): add embedding model to MCP handler (#2310)
- Add embedding model to mcp handlers
- Add integration tests
2026-01-21 00:20:11 +00:00
vaibhavba-google
d7af21bdde tests(cloudhealthcare): use t.Cleanup() instead of defer (#2332)
## Description

Use t.Cleanup() to register cleanup of FHIR and DICOM stores immediately
after creation. This fixes the uncleaned FHIR/DICOM stores that remain
in the project(In the earlier implementation, teardown does not get
triggered if the test failed).

🛠️ Fixes #1986

---------

Co-authored-by: Yuan Teoh <yuanteoh@google.com>
Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
2026-01-20 14:58:33 -08:00
Yuan Teoh
adc9589766 feat: add new user-agent-metadata flag (#2302)
## Description

Add a new `--user-agent-metadata` flag that allows user to append
additional user agent metadata. The flag takes in []string and will
concatenate it with `.`.

```
go run . --user-agent-metadata=foo
```
 produces `0.25.0+dev.darwin.arm64+foo` user agent string

```
go run . --user-agent-metadata=foo,bar
```
produces `0.25.0+dev.darwin.arm64+foo+bar` user agent string

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #<issue_number_goes_here>
2026-01-20 19:23:50 +00:00
Yuan Teoh
c25a2330fe fix: add check for client authorization before retrieving token string (#2327)
Previous refactoring (#2273) accidentally removed the authorization
checks prior to token retrieval. This issue went unnoticed because the
integration tests were disabled. I am re-adding the necessary checks.
2026-01-20 18:57:11 +00:00
37 changed files with 3087 additions and 962 deletions

View File

@@ -87,7 +87,7 @@ steps:
- "CLOUD_SQL_POSTGRES_REGION=$_REGION"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv:
["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID"]
["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
volumes:
- name: "go"
path: "/gopath"
@@ -134,7 +134,7 @@ steps:
- "ALLOYDB_POSTGRES_DATABASE=$_DATABASE_NAME"
- "ALLOYDB_POSTGRES_REGION=$_REGION"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID"]
secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
volumes:
- name: "go"
path: "/gopath"
@@ -293,7 +293,7 @@ steps:
.ci/test_with_coverage.sh \
"Cloud Healthcare API" \
cloudhealthcare \
cloudhealthcare || echo "Integration tests failed."
cloudhealthcare
- id: "postgres"
name: golang:1
@@ -305,7 +305,7 @@ steps:
- "POSTGRES_HOST=$_POSTGRES_HOST"
- "POSTGRES_PORT=$_POSTGRES_PORT"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID"]
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
volumes:
- name: "go"
path: "/gopath"
@@ -964,6 +964,13 @@ steps:
availableSecrets:
secretManager:
# Common secrets
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
env: CLIENT_ID
- versionName: projects/$PROJECT_ID/secrets/api_key/versions/latest
env: API_KEY
# Resource-specific secrets
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
env: CLOUD_SQL_POSTGRES_USER
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_pass/versions/latest
@@ -980,8 +987,6 @@ availableSecrets:
env: POSTGRES_USER
- versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest
env: POSTGRES_PASS
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
env: CLIENT_ID
- versionName: projects/$PROJECT_ID/secrets/neo4j_user/versions/latest
env: NEO4J_USER
- versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest

View File

@@ -77,11 +77,24 @@ run_orch_test() {
setup_orch_table
cd "$orch_dir"
echo "--- Debugging npm config for $orch_name ---"
npm config list
echo "--- Active Registry for $orch_name ---"
npm config get registry
echo "--- Inspecting .npmrc files ---"
[ -f ".npmrc" ] && echo "Local .npmrc content:" && cat .npmrc
[ -f "$HOME/.npmrc" ] && echo "Global .npmrc content:" && cat "$HOME/.npmrc"
export GPKG_DEBUG=1
echo "Installing dependencies for $orch_name..."
if [ -f "package-lock.json" ]; then
npm ci
npm ci --loglevel verbose
else
npm install
npm install --loglevel verbose
fi
cd ..

View File

@@ -315,15 +315,15 @@ func Execute() {
type Command struct {
*cobra.Command
cfg server.ServerConfig
logger log.Logger
tools_file string
tools_files []string
tools_folder string
prebuiltConfigs []string
inStream io.Reader
outStream io.Writer
errStream io.Writer
cfg server.ServerConfig
logger log.Logger
tools_file string
tools_files []string
tools_folder string
prebuiltConfig string
inStream io.Reader
outStream io.Writer
errStream io.Writer
}
// NewCommand returns a Command object representing an invocation of the CLI.
@@ -376,16 +376,17 @@ func NewCommand(opts ...Option) *Command {
flags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.")
// Fetch prebuilt tools sources to customize the help description
prebuiltHelp := fmt.Sprintf(
"Use a prebuilt tool configuration by source type. Allowed: '%s'. Can be specified multiple times.",
"Use a prebuilt tool configuration by source type. Allowed: '%s'.",
strings.Join(prebuiltconfigs.GetPrebuiltSources(), "', '"),
)
flags.StringSliceVar(&cmd.prebuiltConfigs, "prebuilt", []string{}, prebuiltHelp)
flags.StringVar(&cmd.prebuiltConfig, "prebuilt", "", prebuiltHelp)
flags.BoolVar(&cmd.cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.")
flags.BoolVar(&cmd.cfg.DisableReload, "disable-reload", false, "Disables dynamic reloading of tools file.")
flags.BoolVar(&cmd.cfg.UI, "ui", false, "Launches the Toolbox UI web server.")
// TODO: Insecure by default. Might consider updating this for v1.0.0
flags.StringSliceVar(&cmd.cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.")
flags.StringSliceVar(&cmd.cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.")
flags.StringSliceVar(&cmd.cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.")
// wrap RunE command so that we have access to original Command object
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }
@@ -866,32 +867,24 @@ func run(cmd *Command) error {
var allToolsFiles []ToolsFile
// Load Prebuilt Configuration
if len(cmd.prebuiltConfigs) > 0 {
slices.Sort(cmd.prebuiltConfigs)
sourcesList := strings.Join(cmd.prebuiltConfigs, ", ")
logMsg := fmt.Sprintf("Using prebuilt tool configurations for: %s", sourcesList)
cmd.logger.InfoContext(ctx, logMsg)
for _, configName := range cmd.prebuiltConfigs {
buf, err := prebuiltconfigs.Get(configName)
if err != nil {
cmd.logger.ErrorContext(ctx, err.Error())
return err
}
// Update version string
cmd.cfg.Version += "+prebuilt." + configName
// Parse into ToolsFile struct
parsed, err := parseToolsFile(ctx, buf)
if err != nil {
errMsg := fmt.Errorf("unable to parse prebuilt tool configuration for '%s': %w", configName, err)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
allToolsFiles = append(allToolsFiles, parsed)
if cmd.prebuiltConfig != "" {
buf, err := prebuiltconfigs.Get(cmd.prebuiltConfig)
if err != nil {
cmd.logger.ErrorContext(ctx, err.Error())
return err
}
logMsg := fmt.Sprint("Using prebuilt tool configuration for ", cmd.prebuiltConfig)
cmd.logger.InfoContext(ctx, logMsg)
// Append prebuilt.source to Version string for the User Agent
cmd.cfg.Version += "+prebuilt." + cmd.prebuiltConfig
parsed, err := parseToolsFile(ctx, buf)
if err != nil {
errMsg := fmt.Errorf("unable to parse prebuilt tool configuration: %w", err)
cmd.logger.ErrorContext(ctx, errMsg.Error())
return errMsg
}
allToolsFiles = append(allToolsFiles, parsed)
}
// Determine if Custom Files should be loaded
@@ -899,7 +892,7 @@ func run(cmd *Command) error {
isCustomConfigured := cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != ""
// Determine if default 'tools.yaml' should be used (No prebuilt AND No custom flags)
useDefaultToolsFile := len(cmd.prebuiltConfigs) == 0 && !isCustomConfigured
useDefaultToolsFile := cmd.prebuiltConfig == "" && !isCustomConfigured
if useDefaultToolsFile {
cmd.tools_file = "tools.yaml"

View File

@@ -70,6 +70,9 @@ func withDefaults(c server.ServerConfig) server.ServerConfig {
if c.AllowedHosts == nil {
c.AllowedHosts = []string{"*"}
}
if c.UserAgentMetadata == nil {
c.UserAgentMetadata = []string{}
}
return c
}
@@ -230,6 +233,13 @@ func TestServerConfigFlags(t *testing.T) {
AllowedHosts: []string{"http://foo.com", "http://bar.com"},
}),
},
{
desc: "user agent metadata",
args: []string{"--user-agent-metadata", "foo,bar"},
want: withDefaults(server.ServerConfig{
UserAgentMetadata: []string{"foo", "bar"},
}),
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
@@ -420,27 +430,17 @@ func TestPrebuiltFlag(t *testing.T) {
tcs := []struct {
desc string
args []string
want []string
want string
}{
{
desc: "default value",
args: []string{},
want: []string{},
want: "",
},
{
desc: "single prebuilt flag",
args: []string{"--prebuilt", "alloydb"},
want: []string{"alloydb"},
},
{
desc: "multiple prebuilt flags",
args: []string{"--prebuilt", "alloydb", "--prebuilt", "bigquery"},
want: []string{"alloydb", "bigquery"},
},
{
desc: "comma separated prebuilt flags",
args: []string{"--prebuilt", "alloydb,bigquery"},
want: []string{"alloydb", "bigquery"},
desc: "custom pre built flag",
args: []string{"--tools-file", "alloydb"},
want: "alloydb",
},
}
for _, tc := range tcs {
@@ -449,8 +449,8 @@ func TestPrebuiltFlag(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error invoking command: %s", err)
}
if diff := cmp.Diff(c.prebuiltConfigs, tc.want); diff != "" {
t.Fatalf("got %v, want %v, diff %s", c.prebuiltConfigs, tc.want, diff)
if c.tools_file != tc.want {
t.Fatalf("got %v, want %v", c.cfg, tc.want)
}
})
}
@@ -2073,12 +2073,6 @@ authSources:
return nil
},
},
{
desc: "sqlite called twice error",
args: []string{"--prebuilt", "sqlite", "--prebuilt", "sqlite"},
wantErr: true,
errString: "resource conflicts detected",
},
{
desc: "tool conflict error",
args: []string{"--prebuilt", "sqlite", "--tools-file", toolConflictFile},

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@
"author": "",
"license": "ISC",
"dependencies": {
"@google/adk": "^0.1.3",
"@google/adk": "^0.2.4",
"@toolbox-sdk/adk": "^0.1.5"
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,7 @@
"author": "",
"license": "ISC",
"dependencies": {
"@llamaindex/google": "^0.3.20",
"@llamaindex/google": "^0.4.0",
"@llamaindex/workflow": "^1.1.22",
"@toolbox-sdk/core": "^0.1.2",
"llamaindex": "^0.12.0"

View File

@@ -16,7 +16,7 @@ description: >
| | `--log-level` | Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'. | `info` |
| | `--logging-format` | Specify logging format to use. Allowed: 'standard' or 'JSON'. | `standard` |
| `-p` | `--port` | Port the server will listen on. | `5000` |
| | `--prebuilt` | Use one or more prebuilt tool configuration by source type. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | |
| | `--prebuilt` | Use a prebuilt tool configuration by source type. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | |
| | `--stdio` | Listens via MCP STDIO instead of acting as a remote HTTP server. | |
| | `--telemetry-gcp` | Enable exporting directly to Google Cloud Monitoring. | |
| | `--telemetry-otlp` | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318') | |
@@ -27,6 +27,7 @@ description: >
| | `--ui` | Launches the Toolbox UI web server. | |
| | `--allowed-origins` | Specifies a list of origins permitted to access this server for CORs access. | `*` |
| | `--allowed-hosts` | Specifies a list of hosts permitted to access this server to prevent DNS rebinding attacks. | `*` |
| | `--user-agent-extra` | Appends additional metadata to the User-Agent. | |
| `-v` | `--version` | version for toolbox | |
## Examples
@@ -50,11 +51,6 @@ description: >
# Server with prebuilt + custom tools configurations
./toolbox --tools-file tools.yaml --prebuilt alloydb-postgres
# Server with multiple prebuilt tools configurations
./toolbox --prebuilt alloydb-postgres,alloydb-postgres-admin
# OR
./toolbox --prebuilt alloydb-postgres --prebuilt alloydb-postgres-admin
```
### Tool Configuration Sources
@@ -75,7 +71,7 @@ The CLI supports multiple mutually exclusive ways to specify tool configurations
**Prebuilt Configurations:**
- `--prebuilt`: Use one or more predefined configurations for specific database types (e.g.,
- `--prebuilt`: Use predefined configurations for specific database types (e.g.,
'bigquery', 'postgres', 'spanner'). See [Prebuilt Tools
Reference](prebuilt-tools.md) for allowed values.

View File

@@ -16,9 +16,6 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP.
{{< notice tip >}}
You can now use `--prebuilt` along `--tools-file`, `--tools-files`, or
`--tools-folder` to combine prebuilt configs with custom tools.
You can also combine multiple prebuilt configs.
See [Usage Examples](../reference/cli.md#examples).
{{< /notice >}}

View File

@@ -64,12 +64,14 @@ type ServerConfig struct {
Stdio bool
// DisableReload indicates if the user has disabled dynamic reloading for Toolbox.
DisableReload bool
// UI indicates if Toolbox UI endpoints (/ui) are available
// UI indicates if Toolbox UI endpoints (/ui) are available.
UI bool
// Specifies a list of origins permitted to access this server.
AllowedOrigins []string
// Specifies a list of hosts permitted to access this server
// Specifies a list of hosts permitted to access this server.
AllowedHosts []string
// UserAgentMetadata specifies additional metadata to append to the User-Agent string.
UserAgentMetadata []string
}
type logFormat string

View File

@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

View File

@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

View File

@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

View File

@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

View File

@@ -64,7 +64,11 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
map[string]prompts.Promptset,
error,
) {
ctx = util.WithUserAgent(ctx, cfg.Version)
metadataStr := cfg.Version
if len(cfg.UserAgentMetadata) > 0 {
metadataStr += "+" + strings.Join(cfg.UserAgentMetadata, "+")
}
ctx = util.WithUserAgent(ctx, metadataStr)
instrumentation, err := util.InstrumentationFromContext(ctx)
if err != nil {
panic(err)

View File

@@ -103,10 +103,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.FHIRFetchPage(ctx, url, tokenStr)
}

View File

@@ -131,9 +131,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
var opts []googleapi.CallOption

View File

@@ -161,9 +161,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
var summary bool

View File

@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.GetDataset(tokenStr)
}

View File

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.GetDICOMStore(storeID, tokenStr)
}

View File

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.GetDICOMStoreMetrics(storeID, tokenStr)
}

View File

@@ -130,9 +130,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.GetFHIRResource(storeID, resType, resID, tokenStr)
}

View File

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.GetFHIRStore(storeID, tokenStr)
}

View File

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.GetFHIRStoreMetrics(storeID, tokenStr)
}

View File

@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.ListDICOMStores(tokenStr)
}

View File

@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.ListFHIRStores(tokenStr)
}

View File

@@ -127,9 +127,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
study, ok := params.AsMap()[studyInstanceUIDKey].(string)
if !ok {

View File

@@ -140,9 +140,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
opts, err := common.ParseDICOMSearchParameters(params, []string{sopInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})

View File

@@ -138,9 +138,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
opts, err := common.ParseDICOMSearchParameters(params, []string{seriesInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})

View File

@@ -133,9 +133,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
opts, err := common.ParseDICOMSearchParameters(params, []string{studyInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey})
if err != nil {

View File

@@ -147,12 +147,20 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t)
// Set up table for semanti search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
defer tearDownVectorTable(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
// Add semantic search tool config
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, AlloyDBPostgresToolKind, insertStmt, searchStmt)
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)

View File

@@ -112,8 +112,7 @@ func TestHealthcareToolEndpoints(t *testing.T) {
fhirStoreID := "fhir-store-" + uuid.New().String()
dicomStoreID := "dicom-store-" + uuid.New().String()
patient1ID, patient2ID, teardown := setupHealthcareResources(t, healthcareService, healthcareDataset, fhirStoreID, dicomStoreID)
defer teardown(t)
patient1ID, patient2ID := setupHealthcareResources(t, healthcareService, healthcareDataset, fhirStoreID, dicomStoreID)
toolsFile := getToolsConfig(sourceConfig)
toolsFile = addClientAuthSourceConfig(t, toolsFile)
@@ -173,10 +172,8 @@ func TestHealthcareToolWithStoreRestriction(t *testing.T) {
disallowedFHIRStoreID := "fhir-store-disallowed-" + uuid.New().String()
disallowedDICOMStoreID := "dicom-store-disallowed-" + uuid.New().String()
_, _, teardownAllowedStores := setupHealthcareResources(t, healthcareService, healthcareDataset, allowedFHIRStoreID, allowedDICOMStoreID)
defer teardownAllowedStores(t)
_, _, teardownDisallowedStores := setupHealthcareResources(t, healthcareService, healthcareDataset, disallowedFHIRStoreID, disallowedDICOMStoreID)
defer teardownDisallowedStores(t)
setupHealthcareResources(t, healthcareService, healthcareDataset, allowedFHIRStoreID, allowedDICOMStoreID)
setupHealthcareResources(t, healthcareService, healthcareDataset, disallowedFHIRStoreID, disallowedDICOMStoreID)
// Configure source with dataset restriction.
sourceConfig["allowedFhirStores"] = []string{allowedFHIRStoreID}
@@ -257,7 +254,7 @@ func newHealthcareService(ctx context.Context) (*healthcare.Service, error) {
return healthcareService, nil
}
func setupHealthcareResources(t *testing.T, service *healthcare.Service, datasetID, fhirStoreID, dicomStoreID string) (string, string, func(*testing.T)) {
func setupHealthcareResources(t *testing.T, service *healthcare.Service, datasetID, fhirStoreID, dicomStoreID string) (string, string) {
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", healthcareProject, healthcareRegion, datasetID)
var err error
@@ -266,12 +263,24 @@ func setupHealthcareResources(t *testing.T, service *healthcare.Service, dataset
if fhirStore, err = service.Projects.Locations.Datasets.FhirStores.Create(datasetName, fhirStore).FhirStoreId(fhirStoreID).Do(); err != nil {
t.Fatalf("failed to create fhir store: %v", err)
}
// Register cleanup
t.Cleanup(func() {
if _, err := service.Projects.Locations.Datasets.FhirStores.Delete(fhirStore.Name).Do(); err != nil {
t.Logf("failed to delete fhir store: %v", err)
}
})
// Create DICOM store
dicomStore := &healthcare.DicomStore{}
if dicomStore, err = service.Projects.Locations.Datasets.DicomStores.Create(datasetName, dicomStore).DicomStoreId(dicomStoreID).Do(); err != nil {
t.Fatalf("failed to create dicom store: %v", err)
}
// Register cleanup
t.Cleanup(func() {
if _, err := service.Projects.Locations.Datasets.DicomStores.Delete(dicomStore.Name).Do(); err != nil {
t.Logf("failed to delete dicom store: %v", err)
}
})
// Create Patient 1
patient1Body := bytes.NewBuffer([]byte(`{
@@ -317,15 +326,7 @@ func setupHealthcareResources(t *testing.T, service *healthcare.Service, dataset
createFHIRResource(t, service, fhirStore.Name, "Observation", observation2Body)
}
teardown := func(t *testing.T) {
if _, err := service.Projects.Locations.Datasets.FhirStores.Delete(fhirStore.Name).Do(); err != nil {
t.Logf("failed to delete fhir store: %v", err)
}
if _, err := service.Projects.Locations.Datasets.DicomStores.Delete(dicomStore.Name).Do(); err != nil {
t.Logf("failed to delete dicom store: %v", err)
}
}
return patient1ID, patient2ID, teardown
return patient1ID, patient2ID
}
func getToolsConfig(sourceConfig map[string]any) map[string]any {

View File

@@ -132,12 +132,20 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t)
// Set up table for semantic search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
defer tearDownVectorTable(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
// Add semantic search tool config
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, CloudSQLPostgresToolKind, insertStmt, searchStmt)
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -186,6 +194,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
tests.RunPostgresListRolesTest(t, ctx, pool)
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox")
}
// Test connection with different IP type

251
tests/embedding.go Normal file
View File

@@ -0,0 +1,251 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package tests contains end to end tests meant to verify the Toolbox Server
// works as expected when executed as a binary.
package tests
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"testing"
"github.com/google/uuid"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/jackc/pgx/v5/pgxpool"
)
var apiKey = os.Getenv("API_KEY")
// AddSemanticSearchConfig adds embedding models and semantic search tools to the config
// with configurable tool kind and SQL statements.
func AddSemanticSearchConfig(t *testing.T, config map[string]any, toolKind, insertStmt, searchStmt string) map[string]any {
config["embeddingModels"] = map[string]any{
"gemini_model": map[string]any{
"kind": "gemini",
"model": "gemini-embedding-001",
"apiKey": apiKey,
"dimension": 768,
},
}
tools, ok := config["tools"].(map[string]any)
if !ok {
t.Fatalf("unable to get tools from config")
}
tools["insert_docs"] = map[string]any{
"kind": toolKind,
"source": "my-instance",
"description": "Stores content and its vector embedding into the documents table.",
"statement": insertStmt,
"parameters": []any{
map[string]any{
"name": "content",
"type": "string",
"description": "The text content associated with the vector.",
},
map[string]any{
"name": "text_to_embed",
"type": "string",
"description": "The text content used to generate the vector.",
"embeddedBy": "gemini_model",
},
},
}
tools["search_docs"] = map[string]any{
"kind": toolKind,
"source": "my-instance",
"description": "Finds the most semantically similar document to the query vector.",
"statement": searchStmt,
"parameters": []any{
map[string]any{
"name": "query",
"type": "string",
"description": "The text content to search for.",
"embeddedBy": "gemini_model",
},
},
}
config["tools"] = tools
return config
}
// RunSemanticSearchToolInvokeTest runs the insert_docs and search_docs tools
// via both HTTP and MCP endpoints and verifies the output.
func RunSemanticSearchToolInvokeTest(t *testing.T, insertWant, mcpInsertWant, searchWant string) {
// Initialize MCP session once for the MCP test cases
sessionId := RunInitialize(t, "2024-11-05")
tcs := []struct {
name string
api string
isMcp bool
requestBody interface{}
want string
}{
{
name: "HTTP invoke insert_docs",
api: "http://127.0.0.1:5000/api/tool/insert_docs/invoke",
isMcp: false,
requestBody: `{"content": "The quick brown fox jumps over the lazy dog", "text_to_embed": "The quick brown fox jumps over the lazy dog"}`,
want: insertWant,
},
{
name: "HTTP invoke search_docs",
api: "http://127.0.0.1:5000/api/tool/search_docs/invoke",
isMcp: false,
requestBody: `{"query": "fast fox jumping"}`,
want: searchWant,
},
{
name: "MCP invoke insert_docs",
api: "http://127.0.0.1:5000/mcp",
isMcp: true,
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "mcp-insert-docs",
Request: jsonrpc.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": "insert_docs",
"arguments": map[string]any{
"content": "The quick brown fox jumps over the lazy dog",
"text_to_embed": "The quick brown fox jumps over the lazy dog",
},
},
},
want: mcpInsertWant,
},
{
name: "MCP invoke search_docs",
api: "http://127.0.0.1:5000/mcp",
isMcp: true,
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "mcp-search-docs",
Request: jsonrpc.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": "search_docs",
"arguments": map[string]any{
"query": "fast fox jumping",
},
},
},
want: searchWant,
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
var bodyReader io.Reader
headers := map[string]string{}
// Prepare Request Body and Headers
if tc.isMcp {
reqBytes, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("failed to marshal mcp request: %v", err)
}
bodyReader = bytes.NewBuffer(reqBytes)
if sessionId != "" {
headers["Mcp-Session-Id"] = sessionId
}
} else {
bodyReader = bytes.NewBufferString(tc.requestBody.(string))
}
// Send Request
resp, respBody := RunRequest(t, http.MethodPost, tc.api, bodyReader, headers)
if resp.StatusCode != http.StatusOK {
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
}
// Normalize Response to get the actual tool result string
var got string
if tc.isMcp {
var mcpResp struct {
Result struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
} `json:"result"`
}
if err := json.Unmarshal(respBody, &mcpResp); err != nil {
t.Fatalf("error parsing mcp response: %s", err)
}
if len(mcpResp.Result.Content) > 0 {
got = mcpResp.Result.Content[0].Text
}
} else {
var httpResp map[string]interface{}
if err := json.Unmarshal(respBody, &httpResp); err != nil {
t.Fatalf("error parsing http response: %s", err)
}
if res, ok := httpResp["result"].(string); ok {
got = res
}
}
if !strings.Contains(got, tc.want) {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
})
}
}
// SetupPostgresVectorTable sets up the vector extension and a vector table
func SetupPostgresVectorTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool) (string, func(*testing.T)) {
t.Helper()
if _, err := pool.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector"); err != nil {
t.Fatalf("failed to create vector extension: %v", err)
}
tableName := "vector_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
createTableStmt := fmt.Sprintf(`CREATE TABLE %s (
id SERIAL PRIMARY KEY,
content TEXT,
embedding vector(768)
)`, tableName)
if _, err := pool.Exec(ctx, createTableStmt); err != nil {
t.Fatalf("failed to create table %s: %v", tableName, err)
}
return tableName, func(t *testing.T) {
if _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)); err != nil {
t.Errorf("failed to drop table %s: %v", tableName, err)
}
}
}
func GetPostgresVectorSearchStmts(vectorTableName string) (string, string) {
insertStmt := fmt.Sprintf("INSERT INTO %s (content, embedding) VALUES ($1, $2)", vectorTableName)
searchStmt := fmt.Sprintf("SELECT id, content, embedding <-> $1 AS distance FROM %s ORDER BY distance LIMIT 1", vectorTableName)
return insertStmt, searchStmt
}

View File

@@ -111,6 +111,10 @@ func TestPostgres(t *testing.T) {
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t)
// Set up table for semantic search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
defer tearDownVectorTable(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
@@ -118,6 +122,10 @@ func TestPostgres(t *testing.T) {
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
// Add semantic search tool config
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, PostgresToolKind, insertStmt, searchStmt)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
@@ -165,4 +173,5 @@ func TestPostgres(t *testing.T) {
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
tests.RunPostgresListRolesTest(t, ctx, pool)
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox")
}

View File

@@ -1240,7 +1240,10 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user
var filteredGot []any
for _, item := range got {
if tableMap, ok := item.(map[string]interface{}); ok {
if schema, ok := tableMap["schema_name"]; ok && schema == "public" {
name, _ := tableMap["object_name"].(string)
// Only keep the table if it matches expected test tables
if name == tableNameParam || name == tableNameAuth {
filteredGot = append(filteredGot, item)
}
}