mirror of https://github.com/googleapis/genai-toolbox.git
synced 2026-01-21 05:18:14 -05:00

Compare commits: tool-name-... → dual
9 Commits
| Author | SHA1 | Date |
|---|---|---|
| | c60344f0c1 | |
| | 2cd8c4692e | |
| | e4f60e5633 | |
| | d7af21bdde | |
| | adc9589766 | |
| | c25a2330fe | |
| | 6e09b08c6a | |
| | 1f15a111f1 | |
| | dfddeb528d | |
@@ -87,7 +87,7 @@ steps:
- "CLOUD_SQL_POSTGRES_REGION=$_REGION"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv:
["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID"]
["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
volumes:
- name: "go"
path: "/gopath"
@@ -134,7 +134,7 @@ steps:
- "ALLOYDB_POSTGRES_DATABASE=$_DATABASE_NAME"
- "ALLOYDB_POSTGRES_REGION=$_REGION"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID"]
secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
volumes:
- name: "go"
path: "/gopath"
@@ -293,7 +293,7 @@ steps:
.ci/test_with_coverage.sh \
"Cloud Healthcare API" \
cloudhealthcare \
cloudhealthcare || echo "Integration tests failed."
cloudhealthcare

- id: "postgres"
name: golang:1
@@ -305,7 +305,7 @@ steps:
- "POSTGRES_HOST=$_POSTGRES_HOST"
- "POSTGRES_PORT=$_POSTGRES_PORT"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID"]
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
volumes:
- name: "go"
path: "/gopath"
@@ -964,6 +964,13 @@ steps:

availableSecrets:
secretManager:
# Common secrets
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
env: CLIENT_ID
- versionName: projects/$PROJECT_ID/secrets/api_key/versions/latest
env: API_KEY

# Resource-specific secrets
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
env: CLOUD_SQL_POSTGRES_USER
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_pass/versions/latest
@@ -980,8 +987,6 @@ availableSecrets:
env: POSTGRES_USER
- versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest
env: POSTGRES_PASS
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
env: CLIENT_ID
- versionName: projects/$PROJECT_ID/secrets/neo4j_user/versions/latest
env: NEO4J_USER
- versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest
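
For context, Cloud Build only injects a secret into a step when the secret is declared under `availableSecrets.secretManager` and the step opts in through its `secretEnv` list; the hunks above do both for the new `API_KEY`. A minimal sketch of how the two halves fit together (the step `id` and image come from the diff above, the rest is illustrative):

```yaml
# Sketch only: a test step consuming the API_KEY secret declared above.
steps:
  - id: "postgres"
    name: golang:1
    env:
      - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
    # Opt this step into the secrets it needs.
    secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID", "API_KEY"]

availableSecrets:
  secretManager:
    - versionName: projects/$PROJECT_ID/secrets/api_key/versions/latest
      env: API_KEY   # made available to opted-in steps as the API_KEY environment variable
```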
@@ -386,6 +386,7 @@ func NewCommand(opts ...Option) *Command {
// TODO: Insecure by default. Might consider updating this for v1.0.0
flags.StringSliceVar(&cmd.cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.")
flags.StringSliceVar(&cmd.cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.")
flags.StringSliceVar(&cmd.cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.")

// wrap RunE command so that we have access to original Command object
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }

122 cmd/root_test.go
@@ -70,6 +70,9 @@ func withDefaults(c server.ServerConfig) server.ServerConfig {
if c.AllowedHosts == nil {
c.AllowedHosts = []string{"*"}
}
if c.UserAgentMetadata == nil {
c.UserAgentMetadata = []string{}
}
return c
}

@@ -230,6 +233,13 @@ func TestServerConfigFlags(t *testing.T) {
AllowedHosts: []string{"http://foo.com", "http://bar.com"},
}),
},
{
desc: "user agent metadata",
args: []string{"--user-agent-metadata", "foo,bar"},
want: withDefaults(server.ServerConfig{
UserAgentMetadata: []string{"foo", "bar"},
}),
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
@@ -2171,3 +2181,115 @@ func TestDefaultToolsFileBehavior(t *testing.T) {
})
}
}

func TestParameterReferenceValidation(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

// Base template
baseYaml := `
sources:
dummy-source:
kind: http
baseUrl: http://example.com
tools:
test-tool:
kind: postgres-sql
source: dummy-source
description: test tool
statement: SELECT 1;
parameters:
%s`

tcs := []struct {
desc string
params string
wantErr bool
errSubstr string
}{
{
desc: "valid backward reference",
params: `
- name: source_param
type: string
description: source
- name: copy_param
type: string
description: copy
valueFromParam: source_param`,
wantErr: false,
},
{
desc: "valid forward reference (out of order)",
params: `
- name: copy_param
type: string
description: copy
valueFromParam: source_param
- name: source_param
type: string
description: source`,
wantErr: false,
},
{
desc: "invalid missing reference",
params: `
- name: copy_param
type: string
description: copy
valueFromParam: non_existent_param`,
wantErr: true,
errSubstr: "references '\"non_existent_param\"' in the 'valueFromParam' field",
},
{
desc: "invalid self reference",
params: `
- name: myself
type: string
description: self
valueFromParam: myself`,
wantErr: true,
errSubstr: "parameter \"myself\" cannot copy value from itself",
},
{
desc: "multiple valid references",
params: `
- name: a
type: string
description: a
- name: b
type: string
description: b
valueFromParam: a
- name: c
type: string
description: c
valueFromParam: a`,
wantErr: false,
},
}

for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
// Indent parameters to match YAML structure
yamlContent := fmt.Sprintf(baseYaml, tc.params)

_, err := parseToolsFile(ctx, []byte(yamlContent))

if tc.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), tc.errSubstr) {
t.Errorf("error %q does not contain expected substring %q", err.Error(), tc.errSubstr)
}
} else {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
})
}
}

@@ -207,6 +207,7 @@ You can connect to Toolbox Cloud Run instances directly through the SDK.
{{< tab header="Python" lang="python" >}}
import asyncio
from toolbox_core import ToolboxClient, auth_methods
from toolbox_core.protocol import Protocol

# Replace with the Cloud Run service URL generated in the previous step
URL = "https://cloud-run-url.app"
@@ -217,6 +218,7 @@ async def main():
async with ToolboxClient(
URL,
client_headers={"Authorization": auth_token_provider},
protocol=Protocol.TOOLBOX,
) as toolbox:
toolset = await toolbox.load_toolset()
# ...
@@ -281,3 +283,5 @@ contain the specific error message needed to diagnose the problem.
Manager, it means the Toolbox service account is missing permissions.
- Ensure the `toolbox-identity` service account has the **Secret Manager
Secret Accessor** (`roles/secretmanager.secretAccessor`) IAM role.

- **Cloud Run Connections via IAP:** Currently we do not support Cloud Run connections via [IAP](https://docs.cloud.google.com/iap/docs/concepts-overview). Please disable IAP if you are using it.
@@ -27,6 +27,7 @@ description: >
| | `--ui` | Launches the Toolbox UI web server. | |
| | `--allowed-origins` | Specifies a list of origins permitted to access this server for CORS access. | `*` |
| | `--allowed-hosts` | Specifies a list of hosts permitted to access this server to prevent DNS rebinding attacks. | `*` |
| | `--user-agent-metadata` | Appends additional metadata to the User-Agent. | |
| `-v` | `--version` | version for toolbox | |

## Examples

@@ -3,13 +3,14 @@ title: "EmbeddingModels"
type: docs
weight: 2
description: >
EmbeddingModels represent services that transform text into vector embeddings for semantic search.
EmbeddingModels represent services that transform text into vector embeddings
for semantic search.
---

EmbeddingModels represent services that generate vector representations of text
data. In the MCP Toolbox, these models enable **Semantic Queries**,
allowing [Tools](../tools/) to automatically convert human-readable text into
numerical vectors before using them in a query.
data. In the MCP Toolbox, these models enable **Semantic Queries**, allowing
[Tools](../tools/) to automatically convert human-readable text into numerical
vectors before using them in a query.

This is primarily used in two scenarios:

@@ -19,14 +20,33 @@ This is primarily used in two scenarios:
- **Semantic Search**: Converting a natural language query into a vector to
perform similarity searches.

## Hidden Parameter Duplication (`valueFromParam`)

When building tools for vector ingestion, you often need the same input string
twice:

1. To store the original text in a TEXT column.
1. To generate the vector embedding for a VECTOR column.

Requesting an Agent (LLM) to output the exact same string twice is inefficient
and error-prone. The `valueFromParam` field solves this by allowing a parameter
to inherit its value from another parameter in the same tool.

### Key Behaviors

1. **Hidden from Manifest**: Parameters with `valueFromParam` set are excluded from
the tool definition sent to the Agent. The Agent does not know this parameter
exists.
1. **Auto-Filled**: When the tool is executed, the Toolbox automatically copies the
value from the referenced parameter before processing embeddings.
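
To make the behavior concrete, here is a small sketch (parameter names are illustrative; the full configuration format is shown in the example below). The Agent-facing manifest lists only `content`; Toolbox fills in the hidden copy and embeds it at invocation time:

```yaml
# Sketch only: one visible parameter plus one hidden copy of it.
parameters:
  - name: content            # visible to the Agent
    type: string
    description: The raw text to store.
  - name: text_to_embed      # hidden: copied from 'content', then embedded
    type: string
    description: The text to vectorize.
    valueFromParam: content
    embeddedBy: gemini-model

# Agent sends:            {"content": "The quick brown fox"}
# Toolbox executes with:  content       = "The quick brown fox"
#                         text_to_embed = embedding("The quick brown fox")
```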

## Example

The following configuration defines an embedding model and applies it to
specific tool parameters.

{{< notice tip >}}
Use environment variable replacement with the format ${ENV_NAME}
instead of hardcoding your API keys into the configuration file.
{{< notice tip >}} Use environment variable replacement with the format
${ENV_NAME} instead of hardcoding your API keys into the configuration file.
{{< /notice >}}

### Step 1 - Define an Embedding Model
@@ -40,14 +60,12 @@ embeddingModels:
model: gemini-embedding-001
apiKey: ${GOOGLE_API_KEY}
dimension: 768

```

### Step 2 - Embed Tool Parameters

To use the defined embedding model, embed your query parameters using the
`embeddedBy` field. Only string-typed
parameters can be embedded:
`embeddedBy` field. Only string-typed parameters can be embedded:

```yaml
tools:
@@ -61,10 +79,13 @@ tools:
parameters:
- name: content
type: string
description: The raw text content to be stored in the database.
- name: vector_string
type: string
description: The text to be vectorized and stored.
embeddedBy: gemini-model # refers to the name of a defined embedding model
# This parameter is hidden from the LLM.
# It automatically copies the value from 'content' and embeds it.
valueFromParam: content
embeddedBy: gemini-model

# Semantic search tool
search_embedding:

@@ -12,6 +12,9 @@ aliases:

The `cloud-gemini-data-analytics-query` tool allows you to send natural language questions to the Gemini Data Analytics API and receive structured responses containing SQL queries, natural language answers, and explanations. For details on defining data agent context for database data sources, see the official [documentation](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases).

> [!NOTE]
> Only `alloydb`, `spannerReference`, and `cloudSqlReference` are supported as [datasource references](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1beta/projects.locations.dataAgents#DatasourceReferences).

## Example

```yaml

@@ -30,6 +30,10 @@ following config for example:
- name: userNames
type: array
description: The user names to be set.
items:
name: userName # the item name doesn't matter but it has to exist
type: string
description: username
```

If the input is an array of strings `["Alice", "Sid", "Bob"]`, the final command

@@ -64,12 +64,14 @@ type ServerConfig struct {
Stdio bool
// DisableReload indicates if the user has disabled dynamic reloading for Toolbox.
DisableReload bool
// UI indicates if Toolbox UI endpoints (/ui) are available
// UI indicates if Toolbox UI endpoints (/ui) are available.
UI bool
// Specifies a list of origins permitted to access this server.
AllowedOrigins []string
// Specifies a list of hosts permitted to access this server
// Specifies a list of hosts permitted to access this server.
AllowedHosts []string
// UserAgentMetadata specifies additional metadata to append to the User-Agent string.
UserAgentMetadata []string
}

type logFormat string
@@ -294,6 +296,43 @@ func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interfac
return fmt.Errorf("invalid 'kind' field for tool %q (must be a string)", name)
}

// validate parameter references
if rawParams, ok := v["parameters"]; ok {
if paramsList, ok := rawParams.([]any); ok {
// Turn params into a map
validParamNames := make(map[string]bool)
for _, rawP := range paramsList {
if pMap, ok := rawP.(map[string]any); ok {
if pName, ok := pMap["name"].(string); ok && pName != "" {
validParamNames[pName] = true
}
}
}

// Validate references
for i, rawP := range paramsList {
pMap, ok := rawP.(map[string]any)
if !ok {
continue
}

pName, _ := pMap["name"].(string)
refName, _ := pMap["valueFromParam"].(string)

if refName != "" {
// Check if the referenced parameter exists
if !validParamNames[refName] {
return fmt.Errorf("tool %q config error: parameter %q (index %d) references '%q' in the 'valueFromParam' field, which is not a defined parameter.", name, pName, i, refName)
}

// Check for self-reference
if refName == pName {
return fmt.Errorf("tool %q config error: parameter %q cannot copy value from itself", name, pName)
}
}
}
}
}
yamlDecoder, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating YAML decoder for tool %q: %w", name, err)

@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))

embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))

embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))

embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))

embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}

// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

@@ -158,3 +158,4 @@ func (r *ResourceManager) GetPromptsMap() map[string]prompts.Prompt {
}
return copiedMap
}

@@ -64,7 +64,11 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
map[string]prompts.Promptset,
error,
) {
ctx = util.WithUserAgent(ctx, cfg.Version)
metadataStr := cfg.Version
if len(cfg.UserAgentMetadata) > 0 {
metadataStr += "+" + strings.Join(cfg.UserAgentMetadata, "+")
}
ctx = util.WithUserAgent(ctx, metadataStr)
instrumentation, err := util.InstrumentationFromContext(ctx)
if err != nil {
panic(err)

@@ -103,10 +103,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
}

tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.FHIRFetchPage(ctx, url, tokenStr)
}

@@ -131,9 +131,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
}

tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}

var opts []googleapi.CallOption

@@ -161,9 +161,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, err
}

tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}

var summary bool

@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.GetDataset(tokenStr)
}

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.GetDICOMStore(storeID, tokenStr)
}

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.GetDICOMStoreMetrics(storeID, tokenStr)
}

@@ -130,9 +130,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.GetFHIRResource(storeID, resType, resID, tokenStr)
}

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.GetFHIRStore(storeID, tokenStr)
}

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.GetFHIRStoreMetrics(storeID, tokenStr)
}

@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.ListDICOMStores(tokenStr)
}

@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
return source.ListFHIRStores(tokenStr)
}

@@ -127,9 +127,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
study, ok := params.AsMap()[studyInstanceUIDKey].(string)
if !ok {

@@ -140,9 +140,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}

opts, err := common.ParseDICOMSearchParameters(params, []string{sopInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})

@@ -138,9 +138,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}

opts, err := common.ParseDICOMSearchParameters(params, []string{seriesInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})

@@ -133,9 +133,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
var tokenStr string
if source.UseClientAuthorization() {
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
opts, err := common.ParseDICOMSearchParameters(params, []string{studyInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey})
if err != nil {

@@ -134,7 +134,12 @@ func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[st
var err error
paramAuthServices := p.GetAuthServices()
name := p.GetName()
if len(paramAuthServices) == 0 {

sourceParamName := p.GetValueFromParam()
if sourceParamName != "" {
v, _ = data[sourceParamName]

} else if len(paramAuthServices) == 0 {
// parse non auth-required parameter
var ok bool
v, ok = data[name]
@@ -318,6 +323,7 @@ type Parameter interface {
GetRequired() bool
GetAuthServices() []ParamAuthService
GetEmbeddedBy() string
GetValueFromParam() string
Parse(any) (any, error)
Manifest() ParameterManifest
McpManifest() (ParameterMcpManifest, []string)
@@ -465,6 +471,9 @@ func ParseParameter(ctx context.Context, p map[string]any, paramType string) (Pa
func (ps Parameters) Manifest() []ParameterManifest {
rtn := make([]ParameterManifest, 0, len(ps))
for _, p := range ps {
if p.GetValueFromParam() != "" {
continue
}
rtn = append(rtn, p.Manifest())
}
return rtn
@@ -476,6 +485,11 @@ func (ps Parameters) McpManifest() (McpToolsSchema, map[string][]string) {
authParam := make(map[string][]string)

for _, p := range ps {
// If the parameter is sourced from another param, skip it in the MCP manifest
if p.GetValueFromParam() != "" {
continue
}

name := p.GetName()
paramManifest, authParamList := p.McpManifest()
defaultV := p.GetDefault()
@@ -509,6 +523,7 @@ type ParameterManifest struct {
Default any `json:"default,omitempty"`
AdditionalProperties any `json:"additionalProperties,omitempty"`
EmbeddedBy string `json:"embeddedBy,omitempty"`
ValueFromParam string `json:"valueFromParam,omitempty"`
}

// ParameterMcpManifest represents properties when served as part of a ToolMcpManifest.
@@ -531,6 +546,7 @@ type CommonParameter struct {
AuthServices []ParamAuthService `yaml:"authServices"`
AuthSources []ParamAuthService `yaml:"authSources"` // Deprecated: Kept for compatibility.
EmbeddedBy string `yaml:"embeddedBy"`
ValueFromParam string `yaml:"valueFromParam"`
}

// GetName returns the name specified for the Parameter.
@@ -588,10 +604,16 @@ func (p *CommonParameter) IsExcludedValues(v any) bool {
return false
}

// GetEmbeddedBy returns the embedding model name for the Parameter.
func (p *CommonParameter) GetEmbeddedBy() string {
return p.EmbeddedBy
}

// GetValueFromParam returns the param value to copy from.
func (p *CommonParameter) GetValueFromParam() string {
return p.ValueFromParam
}

// MatchStringOrRegex checks if the input matches the target
func MatchStringOrRegex(input, target any) bool {
targetS, ok := target.(string)

@@ -147,12 +147,20 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t)

// Set up table for semantic search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
defer tearDownVectorTable(t)

// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")

// Add semantic search tool config
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, AlloyDBPostgresToolKind, insertStmt, searchStmt)

toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)

cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)

@@ -112,8 +112,7 @@ func TestHealthcareToolEndpoints(t *testing.T) {
fhirStoreID := "fhir-store-" + uuid.New().String()
dicomStoreID := "dicom-store-" + uuid.New().String()

patient1ID, patient2ID, teardown := setupHealthcareResources(t, healthcareService, healthcareDataset, fhirStoreID, dicomStoreID)
defer teardown(t)
patient1ID, patient2ID := setupHealthcareResources(t, healthcareService, healthcareDataset, fhirStoreID, dicomStoreID)

toolsFile := getToolsConfig(sourceConfig)
toolsFile = addClientAuthSourceConfig(t, toolsFile)
@@ -173,10 +172,8 @@ func TestHealthcareToolWithStoreRestriction(t *testing.T) {
disallowedFHIRStoreID := "fhir-store-disallowed-" + uuid.New().String()
disallowedDICOMStoreID := "dicom-store-disallowed-" + uuid.New().String()

_, _, teardownAllowedStores := setupHealthcareResources(t, healthcareService, healthcareDataset, allowedFHIRStoreID, allowedDICOMStoreID)
defer teardownAllowedStores(t)
_, _, teardownDisallowedStores := setupHealthcareResources(t, healthcareService, healthcareDataset, disallowedFHIRStoreID, disallowedDICOMStoreID)
defer teardownDisallowedStores(t)
setupHealthcareResources(t, healthcareService, healthcareDataset, allowedFHIRStoreID, allowedDICOMStoreID)
setupHealthcareResources(t, healthcareService, healthcareDataset, disallowedFHIRStoreID, disallowedDICOMStoreID)

// Configure source with dataset restriction.
sourceConfig["allowedFhirStores"] = []string{allowedFHIRStoreID}
@@ -257,7 +254,7 @@ func newHealthcareService(ctx context.Context) (*healthcare.Service, error) {
return healthcareService, nil
}

func setupHealthcareResources(t *testing.T, service *healthcare.Service, datasetID, fhirStoreID, dicomStoreID string) (string, string, func(*testing.T)) {
func setupHealthcareResources(t *testing.T, service *healthcare.Service, datasetID, fhirStoreID, dicomStoreID string) (string, string) {
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", healthcareProject, healthcareRegion, datasetID)
var err error

@@ -266,12 +263,24 @@ func setupHealthcareResources(t *testing.T, service *healthcare.Service, dataset
if fhirStore, err = service.Projects.Locations.Datasets.FhirStores.Create(datasetName, fhirStore).FhirStoreId(fhirStoreID).Do(); err != nil {
t.Fatalf("failed to create fhir store: %v", err)
}
// Register cleanup
t.Cleanup(func() {
if _, err := service.Projects.Locations.Datasets.FhirStores.Delete(fhirStore.Name).Do(); err != nil {
t.Logf("failed to delete fhir store: %v", err)
}
})

// Create DICOM store
dicomStore := &healthcare.DicomStore{}
if dicomStore, err = service.Projects.Locations.Datasets.DicomStores.Create(datasetName, dicomStore).DicomStoreId(dicomStoreID).Do(); err != nil {
t.Fatalf("failed to create dicom store: %v", err)
}
// Register cleanup
t.Cleanup(func() {
if _, err := service.Projects.Locations.Datasets.DicomStores.Delete(dicomStore.Name).Do(); err != nil {
t.Logf("failed to delete dicom store: %v", err)
}
})

// Create Patient 1
patient1Body := bytes.NewBuffer([]byte(`{
@@ -317,15 +326,7 @@ func setupHealthcareResources(t *testing.T, service *healthcare.Service, dataset
createFHIRResource(t, service, fhirStore.Name, "Observation", observation2Body)
}

teardown := func(t *testing.T) {
if _, err := service.Projects.Locations.Datasets.FhirStores.Delete(fhirStore.Name).Do(); err != nil {
t.Logf("failed to delete fhir store: %v", err)
}
if _, err := service.Projects.Locations.Datasets.DicomStores.Delete(dicomStore.Name).Do(); err != nil {
t.Logf("failed to delete dicom store: %v", err)
}
}
return patient1ID, patient2ID, teardown
return patient1ID, patient2ID
}

func getToolsConfig(sourceConfig map[string]any) map[string]any {

@@ -132,12 +132,20 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t)

// Set up table for semantic search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
defer tearDownVectorTable(t)

// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")

// Add semantic search tool config
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, CloudSQLPostgresToolKind, insertStmt, searchStmt)

toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -186,6 +194,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
tests.RunPostgresListRolesTest(t, ctx, pool)
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox")
}

// Test connection with different IP type

251 tests/embedding.go Normal file
@@ -0,0 +1,251 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package tests contains end to end tests meant to verify the Toolbox Server
// works as expected when executed as a binary.

package tests

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"testing"

"github.com/google/uuid"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/jackc/pgx/v5/pgxpool"
)

var apiKey = os.Getenv("API_KEY")

// AddSemanticSearchConfig adds embedding models and semantic search tools to the config
// with configurable tool kind and SQL statements.
func AddSemanticSearchConfig(t *testing.T, config map[string]any, toolKind, insertStmt, searchStmt string) map[string]any {
config["embeddingModels"] = map[string]any{
"gemini_model": map[string]any{
"kind": "gemini",
"model": "gemini-embedding-001",
"apiKey": apiKey,
"dimension": 768,
},
}

tools, ok := config["tools"].(map[string]any)
if !ok {
t.Fatalf("unable to get tools from config")
}

tools["insert_docs"] = map[string]any{
"kind": toolKind,
"source": "my-instance",
"description": "Stores content and its vector embedding into the documents table.",
"statement": insertStmt,
"parameters": []any{
map[string]any{
"name": "content",
"type": "string",
"description": "The text content associated with the vector.",
},
map[string]any{
"name": "text_to_embed",
"type": "string",
"description": "The text content used to generate the vector.",
"embeddedBy": "gemini_model",
"valueFromParam": "content",
},
},
}

tools["search_docs"] = map[string]any{
"kind": toolKind,
"source": "my-instance",
"description": "Finds the most semantically similar document to the query vector.",
"statement": searchStmt,
"parameters": []any{
map[string]any{
"name": "query",
"type": "string",
"description": "The text content to search for.",
"embeddedBy": "gemini_model",
},
},
}

config["tools"] = tools
return config
}

// RunSemanticSearchToolInvokeTest runs the insert_docs and search_docs tools
// via both HTTP and MCP endpoints and verifies the output.
func RunSemanticSearchToolInvokeTest(t *testing.T, insertWant, mcpInsertWant, searchWant string) {
// Initialize MCP session once for the MCP test cases
sessionId := RunInitialize(t, "2024-11-05")

tcs := []struct {
name string
api string
isMcp bool
requestBody interface{}
want string
}{
{
name: "HTTP invoke insert_docs",
api: "http://127.0.0.1:5000/api/tool/insert_docs/invoke",
isMcp: false,
requestBody: `{"content": "The quick brown fox jumps over the lazy dog"}`,
want: insertWant,
},
{
name: "HTTP invoke search_docs",
api: "http://127.0.0.1:5000/api/tool/search_docs/invoke",
isMcp: false,
requestBody: `{"query": "fast fox jumping"}`,
want: searchWant,
},
{
name: "MCP invoke insert_docs",
api: "http://127.0.0.1:5000/mcp",
isMcp: true,
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "mcp-insert-docs",
Request: jsonrpc.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": "insert_docs",
"arguments": map[string]any{
"content": "The quick brown fox jumps over the lazy dog",
},
},
},
want: mcpInsertWant,
},
{
name: "MCP invoke search_docs",
api: "http://127.0.0.1:5000/mcp",
isMcp: true,
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "mcp-search-docs",
Request: jsonrpc.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": "search_docs",
"arguments": map[string]any{
"query": "fast fox jumping",
},
},
},
want: searchWant,
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
var bodyReader io.Reader
headers := map[string]string{}

// Prepare Request Body and Headers
if tc.isMcp {
reqBytes, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("failed to marshal mcp request: %v", err)
}
bodyReader = bytes.NewBuffer(reqBytes)
if sessionId != "" {
headers["Mcp-Session-Id"] = sessionId
}
} else {
bodyReader = bytes.NewBufferString(tc.requestBody.(string))
}

// Send Request
resp, respBody := RunRequest(t, http.MethodPost, tc.api, bodyReader, headers)

if resp.StatusCode != http.StatusOK {
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
}

// Normalize Response to get the actual tool result string
var got string
if tc.isMcp {
var mcpResp struct {
Result struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
} `json:"result"`
}
if err := json.Unmarshal(respBody, &mcpResp); err != nil {
t.Fatalf("error parsing mcp response: %s", err)
}
if len(mcpResp.Result.Content) > 0 {
got = mcpResp.Result.Content[0].Text
}
} else {
var httpResp map[string]interface{}
if err := json.Unmarshal(respBody, &httpResp); err != nil {
t.Fatalf("error parsing http response: %s", err)
}
if res, ok := httpResp["result"].(string); ok {
got = res
}
}

if !strings.Contains(got, tc.want) {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
})
}
}

// SetupPostgresVectorTable sets up the vector extension and a vector table
func SetupPostgresVectorTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool) (string, func(*testing.T)) {
t.Helper()
if _, err := pool.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector"); err != nil {
t.Fatalf("failed to create vector extension: %v", err)
}

tableName := "vector_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")

createTableStmt := fmt.Sprintf(`CREATE TABLE %s (
id SERIAL PRIMARY KEY,
content TEXT,
embedding vector(768)
)`, tableName)

if _, err := pool.Exec(ctx, createTableStmt); err != nil {
t.Fatalf("failed to create table %s: %v", tableName, err)
}

return tableName, func(t *testing.T) {
if _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)); err != nil {
t.Errorf("failed to drop table %s: %v", tableName, err)
}
}
}

func GetPostgresVectorSearchStmts(vectorTableName string) (string, string) {
insertStmt := fmt.Sprintf("INSERT INTO %s (content, embedding) VALUES ($1, $2)", vectorTableName)
searchStmt := fmt.Sprintf("SELECT id, content, embedding <-> $1 AS distance FROM %s ORDER BY distance LIMIT 1", vectorTableName)
return insertStmt, searchStmt
}
@@ -111,6 +111,10 @@ func TestPostgres(t *testing.T) {
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t)

// Set up table for semantic search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
defer tearDownVectorTable(t)

// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
@@ -118,6 +122,10 @@ func TestPostgres(t *testing.T) {
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)

// Add semantic search tool config
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, PostgresToolKind, insertStmt, searchStmt)

cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
@@ -165,4 +173,5 @@ func TestPostgres(t *testing.T) {
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
tests.RunPostgresListRolesTest(t, ctx, pool)
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox")
}

@@ -1240,7 +1240,10 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user
var filteredGot []any
for _, item := range got {
if tableMap, ok := item.(map[string]interface{}); ok {
if schema, ok := tableMap["schema_name"]; ok && schema == "public" {
name, _ := tableMap["object_name"].(string)

// Only keep the table if it matches expected test tables
if name == tableNameParam || name == tableNameAuth {
filteredGot = append(filteredGot, item)
}
}