mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
feat: Add AuthRequired to Tool Manifest (#433)
Add `AuthRequired` to Tool Manifest so SDK could throw an error early for unauthorized Tool invocations. SDK changes: https://github.com/googleapis/mcp-toolbox-sdk-python/pull/72/files Also added `authRequired` to Neo4j and dgraph tools.
This commit is contained in:
124
cmd/root_test.go
124
cmd/root_test.go
@@ -327,6 +327,7 @@ func TestParseToolFile(t *testing.T) {
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("country", "some description"),
|
||||
},
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
Toolsets: server.ToolsetConfigs{
|
||||
@@ -449,11 +450,12 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: postgressql.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
Name: "example_tool",
|
||||
Kind: postgressql.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
AuthRequired: []string{},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("country", "some description"),
|
||||
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
|
||||
@@ -547,11 +549,113 @@ func TestParseToolFileWithAuth(t *testing.T) {
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: postgressql.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
Name: "example_tool",
|
||||
Kind: postgressql.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
AuthRequired: []string{},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("country", "some description"),
|
||||
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
|
||||
tools.NewStringParameterWithAuth("email", "user email", []tools.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
Toolsets: server.ToolsetConfigs{
|
||||
"example_toolset": tools.ToolsetConfig{
|
||||
Name: "example_toolset",
|
||||
ToolNames: []string{"example_tool"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "basic example with authRequired",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
authServices:
|
||||
my-google-service:
|
||||
kind: google
|
||||
clientId: my-client-id
|
||||
other-google-service:
|
||||
kind: google
|
||||
clientId: other-client-id
|
||||
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
authRequired:
|
||||
- my-google-service
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
- name: id
|
||||
type: integer
|
||||
description: user id
|
||||
authServices:
|
||||
- name: my-google-service
|
||||
field: user_id
|
||||
- name: email
|
||||
type: string
|
||||
description: user email
|
||||
authServices:
|
||||
- name: my-google-service
|
||||
field: email
|
||||
- name: other-google-service
|
||||
field: other_email
|
||||
|
||||
toolsets:
|
||||
example_toolset:
|
||||
- example_tool
|
||||
`,
|
||||
wantToolsFile: ToolsFile{
|
||||
Sources: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpgsrc.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: cloudsqlpgsrc.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
IPType: "public",
|
||||
Database: "my_db",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
},
|
||||
},
|
||||
AuthServices: server.AuthServiceConfigs{
|
||||
"my-google-service": google.Config{
|
||||
Name: "my-google-service",
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "my-client-id",
|
||||
},
|
||||
"other-google-service": google.Config{
|
||||
Name: "other-google-service",
|
||||
Kind: google.AuthServiceKind,
|
||||
ClientID: "other-client-id",
|
||||
},
|
||||
},
|
||||
Tools: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: postgressql.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
AuthRequired: []string{"my-google-service"},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("country", "some description"),
|
||||
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
|
||||
|
||||
@@ -316,6 +316,11 @@ func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interfac
|
||||
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
|
||||
}
|
||||
|
||||
// Make `authRequired` an empty list instead of nil for Tool manifest
|
||||
if v["authRequired"] == nil {
|
||||
v["authRequired"] = []string{}
|
||||
}
|
||||
|
||||
kind, ok := v["kind"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'kind' field for %q", name)
|
||||
|
||||
@@ -118,7 +118,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
NLConfig: cfg.NLConfig,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest()},
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.BigQueryClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
|
||||
@@ -52,11 +52,12 @@ func TestParseFromYamlSpanner(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": bigquery.Config{
|
||||
Name: "example_tool",
|
||||
Kind: bigquery.ToolKind,
|
||||
Source: "my-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
Name: "example_tool",
|
||||
Kind: bigquery.ToolKind,
|
||||
Source: "my-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
AuthRequired: []string{},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("country", "some description"),
|
||||
},
|
||||
|
||||
@@ -79,7 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.BigtableClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
|
||||
@@ -52,11 +52,12 @@ func TestParseFromYamlBigtable(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": bigtable.Config{
|
||||
Name: "example_tool",
|
||||
Kind: bigtable.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
Name: "example_tool",
|
||||
Kind: bigtable.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
AuthRequired: []string{},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("country", "some description"),
|
||||
},
|
||||
|
||||
@@ -83,7 +83,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
DgraphClient: s.DgraphClient(),
|
||||
IsQuery: cfg.IsQuery,
|
||||
Timeout: cfg.Timeout,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
|
||||
@@ -42,9 +42,6 @@ func TestParseFromYamlDgraph(t *testing.T) {
|
||||
kind: dgraph-dql
|
||||
source: my-dgraph-instance
|
||||
description: some tool description
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
isQuery: true
|
||||
timeout: 20s
|
||||
statement: |
|
||||
@@ -55,7 +52,7 @@ func TestParseFromYamlDgraph(t *testing.T) {
|
||||
Name: "example_tool",
|
||||
Kind: dgraph.ToolKind,
|
||||
Source: "my-dgraph-instance",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
AuthRequired: []string{},
|
||||
Description: "some tool description",
|
||||
IsQuery: true,
|
||||
Timeout: "20s",
|
||||
@@ -76,11 +73,12 @@ func TestParseFromYamlDgraph(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": dgraph.Config{
|
||||
Name: "example_tool",
|
||||
Kind: dgraph.ToolKind,
|
||||
Source: "my-dgraph-instance",
|
||||
Description: "some tool description",
|
||||
Statement: "mutation {set { _:a <name> \"a@email.com\" . _:b <email> \"b@email.com\" .}}\n",
|
||||
Name: "example_tool",
|
||||
Kind: dgraph.ToolKind,
|
||||
Source: "my-dgraph-instance",
|
||||
Description: "some tool description",
|
||||
AuthRequired: []string{},
|
||||
Statement: "mutation {set { _:a <name> \"a@email.com\" . _:b <email> \"b@email.com\" .}}\n",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -161,7 +161,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Headers: combinedHeaders,
|
||||
Client: s.Client,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest},
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Db: s.MSSQLDB(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
|
||||
@@ -81,7 +81,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.MySQLPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
|
||||
@@ -82,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Driver: s.Neo4jDriver(),
|
||||
Database: s.Neo4jDatabase(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
|
||||
@@ -83,7 +83,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
|
||||
@@ -83,7 +83,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.SpannerClient(),
|
||||
dialect: s.DatabaseDialect(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
|
||||
@@ -52,11 +52,12 @@ func TestParseFromYamlSpanner(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": spanner.Config{
|
||||
Name: "example_tool",
|
||||
Kind: spanner.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
Name: "example_tool",
|
||||
Kind: spanner.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
AuthRequired: []string{},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("country", "some description"),
|
||||
},
|
||||
|
||||
@@ -79,7 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Db: s.SQLiteDB(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
|
||||
@@ -36,8 +36,9 @@ type Tool interface {
|
||||
|
||||
// Manifest is the representation of tools sent to Client SDKs.
|
||||
type Manifest struct {
|
||||
Description string `json:"description"`
|
||||
Parameters []ParameterManifest `json:"parameters"`
|
||||
Description string `json:"description"`
|
||||
Parameters []ParameterManifest `json:"parameters"`
|
||||
AuthRequired []string `json:"authRequired"`
|
||||
}
|
||||
|
||||
// Definition for a tool the MCP client can call.
|
||||
|
||||
@@ -123,6 +123,7 @@ func runAiNlToolGetTest(t *testing.T) {
|
||||
"authSources": []any{},
|
||||
},
|
||||
},
|
||||
"authRequired": []any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -96,8 +96,9 @@ func TestDgraphToolEndpoints(t *testing.T) {
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-dql-tool/",
|
||||
want: map[string]any{
|
||||
"my-simple-dql-tool": map[string]any{
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"parameters": []any{},
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"parameters": []any{},
|
||||
"authRequired": []any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -105,8 +105,9 @@ func TestNeo4jToolEndpoints(t *testing.T) {
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-cypher-tool/",
|
||||
want: map[string]any{
|
||||
"my-simple-cypher-tool": map[string]any{
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"parameters": []any{},
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"parameters": []any{},
|
||||
"authRequired": []any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -263,8 +263,9 @@ func RunToolGetTest(t *testing.T) {
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/",
|
||||
want: map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"parameters": []any{},
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"parameters": []any{},
|
||||
"authRequired": []any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user