feat: Add AuthRequired to Tool Manifest (#433)

Add `AuthRequired` to Tool Manifest so SDK could throw an error early
for unauthorized Tool invocations.
SDK changes:
https://github.com/googleapis/mcp-toolbox-sdk-python/pull/72/files

Also added `authRequired` to Neo4j and dgraph tools.
This commit is contained in:
Wenxin Du
2025-04-23 12:52:04 -04:00
committed by GitHub
parent 8014eea033
commit d9388ad57e
22 changed files with 168 additions and 53 deletions

View File

@@ -327,6 +327,7 @@ func TestParseToolFile(t *testing.T) {
Parameters: []tools.Parameter{
tools.NewStringParameter("country", "some description"),
},
AuthRequired: []string{},
},
},
Toolsets: server.ToolsetConfigs{
@@ -449,11 +450,12 @@ func TestParseToolFileWithAuth(t *testing.T) {
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Kind: postgressql.ToolKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Name: "example_tool",
Kind: postgressql.ToolKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
AuthRequired: []string{},
Parameters: []tools.Parameter{
tools.NewStringParameter("country", "some description"),
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
@@ -547,11 +549,113 @@ func TestParseToolFileWithAuth(t *testing.T) {
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Kind: postgressql.ToolKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Name: "example_tool",
Kind: postgressql.ToolKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
AuthRequired: []string{},
Parameters: []tools.Parameter{
tools.NewStringParameter("country", "some description"),
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
tools.NewStringParameterWithAuth("email", "user email", []tools.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
},
},
},
Toolsets: server.ToolsetConfigs{
"example_toolset": tools.ToolsetConfig{
Name: "example_toolset",
ToolNames: []string{"example_tool"},
},
},
},
},
{
description: "basic example with authRequired",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
authServices:
my-google-service:
kind: google
clientId: my-client-id
other-google-service:
kind: google
clientId: other-client-id
tools:
example_tool:
kind: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
authRequired:
- my-google-service
parameters:
- name: country
type: string
description: some description
- name: id
type: integer
description: user id
authServices:
- name: my-google-service
field: user_id
- name: email
type: string
description: user email
authServices:
- name: my-google-service
field: email
- name: other-google-service
field: other_email
toolsets:
example_toolset:
- example_tool
`,
wantToolsFile: ToolsFile{
Sources: server.SourceConfigs{
"my-pg-instance": cloudsqlpgsrc.Config{
Name: "my-pg-instance",
Kind: cloudsqlpgsrc.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
IPType: "public",
Database: "my_db",
User: "my_user",
Password: "my_pass",
},
},
AuthServices: server.AuthServiceConfigs{
"my-google-service": google.Config{
Name: "my-google-service",
Kind: google.AuthServiceKind,
ClientID: "my-client-id",
},
"other-google-service": google.Config{
Name: "other-google-service",
Kind: google.AuthServiceKind,
ClientID: "other-client-id",
},
},
Tools: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Kind: postgressql.ToolKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
AuthRequired: []string{"my-google-service"},
Parameters: []tools.Parameter{
tools.NewStringParameter("country", "some description"),
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),

View File

@@ -316,6 +316,11 @@ func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interfac
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
}
// Make `authRequired` an empty list instead of nil for Tool manifest
if v["authRequired"] == nil {
v["authRequired"] = []string{}
}
kind, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for %q", name)

View File

@@ -118,7 +118,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
NLConfig: cfg.NLConfig,
AuthRequired: cfg.AuthRequired,
Pool: s.PostgresPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest()},
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}

View File

@@ -81,7 +81,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Client: s.BigQueryClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil

View File

@@ -52,11 +52,12 @@ func TestParseFromYamlSpanner(t *testing.T) {
`,
want: server.ToolConfigs{
"example_tool": bigquery.Config{
Name: "example_tool",
Kind: bigquery.ToolKind,
Source: "my-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Name: "example_tool",
Kind: bigquery.ToolKind,
Source: "my-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
AuthRequired: []string{},
Parameters: []tools.Parameter{
tools.NewStringParameter("country", "some description"),
},

View File

@@ -79,7 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Client: s.BigtableClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil

View File

@@ -52,11 +52,12 @@ func TestParseFromYamlBigtable(t *testing.T) {
`,
want: server.ToolConfigs{
"example_tool": bigtable.Config{
Name: "example_tool",
Kind: bigtable.ToolKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Name: "example_tool",
Kind: bigtable.ToolKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
AuthRequired: []string{},
Parameters: []tools.Parameter{
tools.NewStringParameter("country", "some description"),
},

View File

@@ -83,7 +83,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
DgraphClient: s.DgraphClient(),
IsQuery: cfg.IsQuery,
Timeout: cfg.Timeout,
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil

View File

@@ -42,9 +42,6 @@ func TestParseFromYamlDgraph(t *testing.T) {
kind: dgraph-dql
source: my-dgraph-instance
description: some tool description
authRequired:
- my-google-auth-service
- other-auth-service
isQuery: true
timeout: 20s
statement: |
@@ -55,7 +52,7 @@ func TestParseFromYamlDgraph(t *testing.T) {
Name: "example_tool",
Kind: dgraph.ToolKind,
Source: "my-dgraph-instance",
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
AuthRequired: []string{},
Description: "some tool description",
IsQuery: true,
Timeout: "20s",
@@ -76,11 +73,12 @@ func TestParseFromYamlDgraph(t *testing.T) {
`,
want: server.ToolConfigs{
"example_tool": dgraph.Config{
Name: "example_tool",
Kind: dgraph.ToolKind,
Source: "my-dgraph-instance",
Description: "some tool description",
Statement: "mutation {set { _:a <name> \"a@email.com\" . _:b <email> \"b@email.com\" .}}\n",
Name: "example_tool",
Kind: dgraph.ToolKind,
Source: "my-dgraph-instance",
Description: "some tool description",
AuthRequired: []string{},
Statement: "mutation {set { _:a <name> \"a@email.com\" . _:b <email> \"b@email.com\" .}}\n",
},
},
},

View File

@@ -161,7 +161,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Headers: combinedHeaders,
Client: s.Client,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest},
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
}

View File

@@ -82,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Db: s.MSSQLDB(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil

View File

@@ -81,7 +81,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Pool: s.MySQLPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil

View File

@@ -82,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
AuthRequired: cfg.AuthRequired,
Driver: s.Neo4jDriver(),
Database: s.Neo4jDatabase(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil

View File

@@ -83,7 +83,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Pool: s.PostgresPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil

View File

@@ -83,7 +83,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
AuthRequired: cfg.AuthRequired,
Client: s.SpannerClient(),
dialect: s.DatabaseDialect(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil

View File

@@ -52,11 +52,12 @@ func TestParseFromYamlSpanner(t *testing.T) {
`,
want: server.ToolConfigs{
"example_tool": spanner.Config{
Name: "example_tool",
Kind: spanner.ToolKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Name: "example_tool",
Kind: spanner.ToolKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
AuthRequired: []string{},
Parameters: []tools.Parameter{
tools.NewStringParameter("country", "some description"),
},

View File

@@ -79,7 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Db: s.SQLiteDB(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil

View File

@@ -36,8 +36,9 @@ type Tool interface {
// Manifest is the representation of tools sent to Client SDKs.
type Manifest struct {
Description string `json:"description"`
Parameters []ParameterManifest `json:"parameters"`
Description string `json:"description"`
Parameters []ParameterManifest `json:"parameters"`
AuthRequired []string `json:"authRequired"`
}
// Definition for a tool the MCP client can call.

View File

@@ -123,6 +123,7 @@ func runAiNlToolGetTest(t *testing.T) {
"authSources": []any{},
},
},
"authRequired": []any{},
},
},
},

View File

@@ -96,8 +96,9 @@ func TestDgraphToolEndpoints(t *testing.T) {
api: "http://127.0.0.1:5000/api/tool/my-simple-dql-tool/",
want: map[string]any{
"my-simple-dql-tool": map[string]any{
"description": "Simple tool to test end to end functionality.",
"parameters": []any{},
"description": "Simple tool to test end to end functionality.",
"parameters": []any{},
"authRequired": []any{},
},
},
},

View File

@@ -105,8 +105,9 @@ func TestNeo4jToolEndpoints(t *testing.T) {
api: "http://127.0.0.1:5000/api/tool/my-simple-cypher-tool/",
want: map[string]any{
"my-simple-cypher-tool": map[string]any{
"description": "Simple tool to test end to end functionality.",
"parameters": []any{},
"description": "Simple tool to test end to end functionality.",
"parameters": []any{},
"authRequired": []any{},
},
},
},

View File

@@ -263,8 +263,9 @@ func RunToolGetTest(t *testing.T) {
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/",
want: map[string]any{
"my-simple-tool": map[string]any{
"description": "Simple tool to test end to end functionality.",
"parameters": []any{},
"description": "Simple tool to test end to end functionality.",
"parameters": []any{},
"authRequired": []any{},
},
},
},