mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
feat: adding support for Model Context Protocol (MCP). (#396)
Adding Toolbox support for MCP. Toolbox can now be run as an MCP server. Fixes #312. --------- Co-authored-by: Jack Wotherspoon <jackwoth@google.com> Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Co-authored-by: Averi Kitsch <akitsch@google.com>
This commit is contained in:
@@ -80,9 +80,5 @@ with clients in their preferred language. As Gen AI matures, we want developers
|
||||
|
||||
## Is Toolbox compatible with Model Context Protocol (MCP)?
|
||||
|
||||
Toolbox currently uses it's own custom protocol for server-client communication.
|
||||
[Anthropic's Model Context Protocol (MCP)](https://modelcontextprotocol.io/)
|
||||
launched towards the end of Toolbox's development, and is currently missing
|
||||
functionality to support some of our features. We're currently exploring how
|
||||
best to bring Toolbox's functionality to the wider MCP ecosystem.
|
||||
|
||||
Yes! Toolbox is compatible with [Anthropic's Model Context Protocol (MCP)](https://modelcontextprotocol.io/). Please checkout [Connect via MCP](../how-to/connect_via_mcp.md) on how to
|
||||
connect to Toolbox with an MCP client.
|
||||
@@ -81,18 +81,22 @@ A metric is a measurement of a service captured at runtime. The collected data
|
||||
can be used to provide important insights into the service. Toolbox provides the
|
||||
following custom metrics:
|
||||
|
||||
| **Metric Name** | **Description** |
|
||||
|------------------------------------|-------------------------------------------------------|
|
||||
| `toolbox.server.toolset.get.count` | Counts the number of toolset manifest requests served |
|
||||
| `toolbox.server.tool.get.count` | Counts the number of tool manifest requests served |
|
||||
| `toolbox.server.tool.get.invoke` | Counts the number of tool invocation requests served |
|
||||
| **Metric Name** | **Description** |
|
||||
|------------------------------------|---------------------------------------------------------|
|
||||
| `toolbox.server.toolset.get.count` | Counts the number of toolset manifest requests served |
|
||||
| `toolbox.server.tool.get.count` | Counts the number of tool manifest requests served |
|
||||
| `toolbox.server.tool.get.invoke` | Counts the number of tool invocation requests served |
|
||||
| `toolbox.server.mcp.sse.count` | Counts the number of mcp sse connection requests served |
|
||||
| `toolbox.server.mcp.post.count` | Counts the number of mcp post requests served |
|
||||
|
||||
All custom metrics have the following attributes/labels:
|
||||
|
||||
| **Metric Attributes** | **Description** |
|
||||
|-----------------------|-----------------------------------------------------------|
|
||||
| `toolbox.name` | Name of the toolset or tool, if applicable. |
|
||||
| `toolbox.status` | Operation status code, for example: `success`, `failure`. |
|
||||
| **Metric Attributes** | **Description** |
|
||||
|----------------------------|-----------------------------------------------------------|
|
||||
| `toolbox.name` | Name of the toolset or tool, if applicable. |
|
||||
| `toolbox.operation.status` | Operation status code, for example: `success`, `failure`. |
|
||||
| `toolbox.sse.sessionId` | Session id for sse connection, if applicable. |
|
||||
| `toolbox.method` | Method of JSON-RPC request, if applicable. |
|
||||
|
||||
### Traces
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
---
|
||||
title: "Configuration"
|
||||
type: docs
|
||||
weight: 3
|
||||
weight: 4
|
||||
description: How to configure Toolbox's tools.yaml file.
|
||||
---
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "Quickstart"
|
||||
title: "Quickstart (Local)"
|
||||
type: docs
|
||||
weight: 2
|
||||
description: >
|
||||
|
||||
228
docs/en/getting-started/mcp_quickstart/_index.md
Normal file
228
docs/en/getting-started/mcp_quickstart/_index.md
Normal file
@@ -0,0 +1,228 @@
|
||||
---
|
||||
title: "Quickstart (MCP)"
|
||||
type: docs
|
||||
weight: 3
|
||||
description: >
|
||||
How to get started running Toolbox locally with MCP Inspector.
|
||||
---
|
||||
|
||||
## Overview
|
||||
[Model Context Protocol](https://modelcontextprotocol.io) is an open protocol
|
||||
that standardizes how applications provide context to LLMs. Check out this page
|
||||
on how to [connect to Toolbox via MCP](../../how-to/connect_via_mcp.md).
|
||||
|
||||
## Step 1: Set up your database
|
||||
|
||||
In this section, we will create a database, insert some data that needs to be
|
||||
access by our agent, and create a database user for Toolbox to connect with.
|
||||
|
||||
1. Connect to postgres using the `psql` command:
|
||||
|
||||
```bash
|
||||
psql -h 127.0.0.1 -U postgres
|
||||
```
|
||||
|
||||
Here, `postgres` denotes the default postgres superuser.
|
||||
|
||||
1. Create a new database and a new user:
|
||||
|
||||
{{< notice tip >}}
|
||||
For a real application, it's best to follow the principle of least permission
|
||||
and only grant the privileges your application needs.
|
||||
{{< /notice >}}
|
||||
|
||||
```sql
|
||||
CREATE USER toolbox_user WITH PASSWORD 'my-password';
|
||||
|
||||
CREATE DATABASE toolbox_db;
|
||||
GRANT ALL PRIVILEGES ON DATABASE toolbox_db TO toolbox_user;
|
||||
|
||||
ALTER DATABASE toolbox_db OWNER TO toolbox_user;
|
||||
```
|
||||
|
||||
|
||||
|
||||
1. End the database session:
|
||||
|
||||
```bash
|
||||
\q
|
||||
```
|
||||
|
||||
1. Connect to your database with your new user:
|
||||
|
||||
```bash
|
||||
psql -h 127.0.0.1 -U toolbox_user -d toolbox_db
|
||||
```
|
||||
|
||||
1. Create a table using the following command:
|
||||
|
||||
```sql
|
||||
CREATE TABLE hotels(
|
||||
id INTEGER NOT NULL PRIMARY KEY,
|
||||
name VARCHAR NOT NULL,
|
||||
location VARCHAR NOT NULL,
|
||||
price_tier VARCHAR NOT NULL,
|
||||
checkin_date DATE NOT NULL,
|
||||
checkout_date DATE NOT NULL,
|
||||
booked BIT NOT NULL
|
||||
);
|
||||
```
|
||||
|
||||
1. Insert data into the table.
|
||||
|
||||
```sql
|
||||
INSERT INTO hotels(id, name, location, price_tier, checkin_date, checkout_date, booked)
|
||||
VALUES
|
||||
(1, 'Hilton Basel', 'Basel', 'Luxury', '2024-04-22', '2024-04-20', B'0'),
|
||||
(2, 'Marriott Zurich', 'Zurich', 'Upscale', '2024-04-14', '2024-04-21', B'0'),
|
||||
(3, 'Hyatt Regency Basel', 'Basel', 'Upper Upscale', '2024-04-02', '2024-04-20', B'0'),
|
||||
(4, 'Radisson Blu Lucerne', 'Lucerne', 'Midscale', '2024-04-24', '2024-04-05', B'0'),
|
||||
(5, 'Best Western Bern', 'Bern', 'Upper Midscale', '2024-04-23', '2024-04-01', B'0'),
|
||||
(6, 'InterContinental Geneva', 'Geneva', 'Luxury', '2024-04-23', '2024-04-28', B'0'),
|
||||
(7, 'Sheraton Zurich', 'Zurich', 'Upper Upscale', '2024-04-27', '2024-04-02', B'0'),
|
||||
(8, 'Holiday Inn Basel', 'Basel', 'Upper Midscale', '2024-04-24', '2024-04-09', B'0'),
|
||||
(9, 'Courtyard Zurich', 'Zurich', 'Upscale', '2024-04-03', '2024-04-13', B'0'),
|
||||
(10, 'Comfort Inn Bern', 'Bern', 'Midscale', '2024-04-04', '2024-04-16', B'0');
|
||||
```
|
||||
|
||||
1. End the database session:
|
||||
|
||||
```bash
|
||||
\q
|
||||
```
|
||||
|
||||
## Step 2: Install and configure Toolbox
|
||||
|
||||
In this section, we will download Toolbox, configure our tools in a
|
||||
`tools.yaml`, and then run the Toolbox server.
|
||||
|
||||
1. Download the latest version of Toolbox as a binary:
|
||||
|
||||
{{< notice tip >}}
|
||||
Select the
|
||||
[correct binary](https://github.com/googleapis/genai-toolbox/releases)
|
||||
corresponding to your OS and CPU architecture.
|
||||
{{< /notice >}}
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.2.1/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
1. Make the binary executable:
|
||||
|
||||
```bash
|
||||
chmod +x toolbox
|
||||
```
|
||||
|
||||
1. Write the following into a `tools.yaml` file. Be sure to update any fields
|
||||
such as `user`, `password`, or `database` that you may have customized in the
|
||||
previous step.
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-pg-source:
|
||||
kind: postgres
|
||||
host: 127.0.0.1
|
||||
port: 5432
|
||||
database: toolbox_db
|
||||
user: toolbox_user
|
||||
password: my-password
|
||||
tools:
|
||||
search-hotels-by-name:
|
||||
kind: postgres-sql
|
||||
source: my-pg-source
|
||||
description: Search for hotels based on name.
|
||||
parameters:
|
||||
- name: name
|
||||
type: string
|
||||
description: The name of the hotel.
|
||||
statement: SELECT * FROM hotels WHERE name ILIKE '%' || $1 || '%';
|
||||
search-hotels-by-location:
|
||||
kind: postgres-sql
|
||||
source: my-pg-source
|
||||
description: Search for hotels based on location.
|
||||
parameters:
|
||||
- name: location
|
||||
type: string
|
||||
description: The location of the hotel.
|
||||
statement: SELECT * FROM hotels WHERE location ILIKE '%' || $1 || '%';
|
||||
book-hotel:
|
||||
kind: postgres-sql
|
||||
source: my-pg-source
|
||||
description: >-
|
||||
Book a hotel by its ID. If the hotel is successfully booked, returns a NULL, raises an error if not.
|
||||
parameters:
|
||||
- name: hotel_id
|
||||
type: string
|
||||
description: The ID of the hotel to book.
|
||||
statement: UPDATE hotels SET booked = B'1' WHERE id = $1;
|
||||
update-hotel:
|
||||
kind: postgres-sql
|
||||
source: my-pg-source
|
||||
description: >-
|
||||
Update a hotel's check-in and check-out dates by its ID. Returns a message
|
||||
indicating whether the hotel was successfully updated or not.
|
||||
parameters:
|
||||
- name: hotel_id
|
||||
type: string
|
||||
description: The ID of the hotel to update.
|
||||
- name: checkin_date
|
||||
type: string
|
||||
description: The new check-in date of the hotel.
|
||||
- name: checkout_date
|
||||
type: string
|
||||
description: The new check-out date of the hotel.
|
||||
statement: >-
|
||||
UPDATE hotels SET checkin_date = CAST($2 as date), checkout_date = CAST($3
|
||||
as date) WHERE id = $1;
|
||||
cancel-hotel:
|
||||
kind: postgres-sql
|
||||
source: my-pg-source
|
||||
description: Cancel a hotel by its ID.
|
||||
parameters:
|
||||
- name: hotel_id
|
||||
type: string
|
||||
description: The ID of the hotel to cancel.
|
||||
statement: UPDATE hotels SET booked = B'0' WHERE id = $1;
|
||||
```
|
||||
For more info on tools, check out the `Resources` section of the docs.
|
||||
|
||||
1. Run the Toolbox server, pointing to the `tools.yaml` file created earlier:
|
||||
|
||||
```bash
|
||||
./toolbox --tools_file "tools.yaml"
|
||||
```
|
||||
|
||||
## Step 3: Connect to MCP Inspector
|
||||
|
||||
1. Run the MCP Inspector:
|
||||
|
||||
```bash
|
||||
npx @modelcontextprotocol/inspector
|
||||
```
|
||||
|
||||
1. Type `y` when it asks to install the inspector package.
|
||||
|
||||
1. It should show the folo=lowing when the MCP Inspector is up and runnning:
|
||||
|
||||
```bash
|
||||
🔍 MCP Inspector is up and running at http://127.0.0.1:5173 🚀
|
||||
```
|
||||
|
||||
1. Open the above link in your browser.
|
||||
|
||||
1. For `Transport Type`, select `SSE`.
|
||||
|
||||
1. For `URL`, type in `http://127.0.0.1:5000/mcp/sse`.
|
||||
|
||||
1. Click Connect.
|
||||
|
||||

|
||||
|
||||
1. Select `List Tools`, you will see a list of tools configured in `tools.yaml`.
|
||||
|
||||

|
||||
|
||||
1. Test out your tools here!
|
||||
BIN
docs/en/getting-started/mcp_quickstart/inspector.png
Normal file
BIN
docs/en/getting-started/mcp_quickstart/inspector.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 22 KiB |
BIN
docs/en/getting-started/mcp_quickstart/inspector_tools.png
Normal file
BIN
docs/en/getting-started/mcp_quickstart/inspector_tools.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
76
docs/en/how-to/connect_via_mcp.md
Normal file
76
docs/en/how-to/connect_via_mcp.md
Normal file
@@ -0,0 +1,76 @@
|
||||
---
|
||||
title: "Connect via MCP Client"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
How to connect to Toolbox from a MCP Client.
|
||||
---
|
||||
|
||||
## Toolbox SDKs vs Model Context Protocol (MCP)
|
||||
Toolbox now supports connections via both the native Toolbox SDKs and via [Model Context Protocol (MCP)](https://modelcontextprotocol.io/). However, Toolbox as several features which are not supported in the MCP specification (such as Authenticated Parameters and Authorized invocation).
|
||||
|
||||
We recommend using the native SDKs over MCP clients to leverage these features. The native SDKs can be combined with MCP clients in many cases.
|
||||
|
||||
### Protocol Versions
|
||||
Toolbox currently supports the following versions of MCP specification:
|
||||
* [2024-11-05](https://spec.modelcontextprotocol.io/specification/2024-11-05/)
|
||||
|
||||
### Features Not Supported by MCP
|
||||
Toolbox has several features that are not yet supported in the MCP specification:
|
||||
* **AuthZ/AuthN:** There are no auth implementation in the `2024-11-05` specification. This includes:
|
||||
* [Authenticated Parameters](../resources/tools/_index.md#authenticated-parameters)
|
||||
* [Authorized Invocations](../resources/tools/_index.md#authorized-invocations)
|
||||
* **Toolsets**: MCP does not have the concept of toolset. Hence, all tools are automatically loaded when using Toolbox with MCP.
|
||||
* **Notifications:** Currently, editing Toolbox Tools requires a server restart. Clients should reload tools on disconnect to get the latest version.
|
||||
|
||||
|
||||
## Connecting to Toolbox with an MCP client
|
||||
### Before you begin
|
||||
|
||||
{{< notice note >}}
|
||||
MCP is only compatible with Toolbox version 0.3.0 and above.
|
||||
{{< /notice >}}
|
||||
|
||||
1. [Install](../getting-started/introduction/_index.md#installing-the-server) Toolbox version 0.3.0+.
|
||||
|
||||
1. Make sure you've set up and initialized your database.
|
||||
|
||||
1. [Set up](../getting-started/configure.md) your `tools.yaml` file.
|
||||
|
||||
### Connecting via HTTP
|
||||
Toolbox supports the HTTP transport protocol with and without SSE.
|
||||
|
||||
{{< tabpane text=true >}} {{% tab header="HTTP with SSE" lang="en" %}}
|
||||
Add the following configuration to your MCP client configuration:
|
||||
```bash
|
||||
{
|
||||
"mcpServers": {
|
||||
"toolbox": {
|
||||
"type": "sse",
|
||||
"url": "http://127.0.0.1:5000/mcp/sse",
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
{{% /tab %}} {{% tab header="HTTP POST" lang="en" %}}
|
||||
Connect to Toolbox HTTP POST via `http://127.0.0.1:5000/mcp`.
|
||||
{{% /tab %}} {{< /tabpane >}}
|
||||
|
||||
### Using the MCP Inspector with Toolbox
|
||||
|
||||
Use MCP [Inspector](https://github.com/modelcontextprotocol/inspector) for testing and debugging Toolbox server.
|
||||
|
||||
1. [Run Toolbox](../getting-started/introduction/_index.md#running-the-server).
|
||||
|
||||
1. In a separate terminal, run Inspector directly through `npx`:
|
||||
|
||||
```bash
|
||||
npx @modelcontextprotocol/inspector
|
||||
```
|
||||
|
||||
1. For `Transport Type` dropdown menu, select `SSE`.
|
||||
|
||||
1. For `URL`, type in `http://127.0.0.1:5000/mcp/sse`.
|
||||
|
||||
1. Click the `Connect` button. Voila! You should be able to inspect your toolbox
|
||||
tools!
|
||||
@@ -24,8 +24,9 @@ import (
|
||||
)
|
||||
|
||||
func TestToolsetEndpoint(t *testing.T) {
|
||||
toolsMap, toolsets := setUpResources(t)
|
||||
ts, shutdown := setUpServer(t, toolsMap, toolsets)
|
||||
mockTools := []MockTool{tool1, tool2}
|
||||
toolsMap, toolsets := setUpResources(t, mockTools)
|
||||
ts, shutdown := setUpServer(t, "api", toolsMap, toolsets)
|
||||
defer shutdown()
|
||||
|
||||
// wantResponse is a struct for checks against test cases
|
||||
@@ -118,8 +119,9 @@ func TestToolsetEndpoint(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestToolGetEndpoint(t *testing.T) {
|
||||
toolsMap, toolsets := setUpResources(t)
|
||||
ts, shutdown := setUpServer(t, toolsMap, toolsets)
|
||||
mockTools := []MockTool{tool1, tool2}
|
||||
toolsMap, toolsets := setUpResources(t, mockTools)
|
||||
ts, shutdown := setUpServer(t, "api", toolsMap, toolsets)
|
||||
defer shutdown()
|
||||
|
||||
// wantResponse is a struct for checks against test cases
|
||||
|
||||
@@ -21,8 +21,10 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -38,6 +40,7 @@ type MockTool struct {
|
||||
Name string
|
||||
Description string
|
||||
Params []tools.Parameter
|
||||
manifest tools.Manifest
|
||||
}
|
||||
|
||||
func (t MockTool) Invoke(tools.ParamValues) ([]any, error) {
|
||||
@@ -61,6 +64,29 @@ func (t MockTool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t MockTool) McpManifest() tools.McpManifest {
|
||||
properties := make(map[string]tools.ParameterMcpManifest)
|
||||
required := make([]string, 0)
|
||||
|
||||
for _, p := range t.Params {
|
||||
name := p.GetName()
|
||||
properties[name] = p.McpManifest()
|
||||
required = append(required, name)
|
||||
}
|
||||
|
||||
toolsSchema := tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: properties,
|
||||
Required: required,
|
||||
}
|
||||
|
||||
return tools.McpManifest{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
InputSchema: toolsSchema,
|
||||
}
|
||||
}
|
||||
|
||||
var tool1 = MockTool{
|
||||
Name: "no_params",
|
||||
Params: []tools.Parameter{},
|
||||
@@ -74,15 +100,29 @@ var tool2 = MockTool{
|
||||
},
|
||||
}
|
||||
|
||||
var tool3 = MockTool{
|
||||
Name: "array_param",
|
||||
Description: "some description",
|
||||
Params: tools.Parameters{
|
||||
tools.NewArrayParameter("my_array", "this param is an array of strings", tools.NewStringParameter("my_string", "string item")),
|
||||
},
|
||||
}
|
||||
|
||||
// setUpResources setups resources to test against
|
||||
func setUpResources(t *testing.T) (map[string]tools.Tool, map[string]tools.Toolset) {
|
||||
toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2}
|
||||
func setUpResources(t *testing.T, mockTools []MockTool) (map[string]tools.Tool, map[string]tools.Toolset) {
|
||||
toolsMap := make(map[string]tools.Tool)
|
||||
var allTools []string
|
||||
for _, tool := range mockTools {
|
||||
tool.manifest = tool.Manifest()
|
||||
toolsMap[tool.Name] = tool
|
||||
allTools = append(allTools, tool.Name)
|
||||
}
|
||||
|
||||
toolsets := make(map[string]tools.Toolset)
|
||||
for name, l := range map[string][]string{
|
||||
"": {tool1.Name, tool2.Name},
|
||||
"tool1_only": {tool1.Name},
|
||||
"tool2_only": {tool2.Name},
|
||||
"": allTools,
|
||||
"tool1_only": {allTools[0]},
|
||||
"tool2_only": {allTools[1]},
|
||||
} {
|
||||
tc := tools.ToolsetConfig{Name: name, ToolNames: l}
|
||||
m, err := tc.Initialize(fakeVersionString, toolsMap)
|
||||
@@ -95,7 +135,7 @@ func setUpResources(t *testing.T) (map[string]tools.Tool, map[string]tools.Tools
|
||||
}
|
||||
|
||||
// setUpServer create a new server with tools and toolsets that are given
|
||||
func setUpServer(t *testing.T, tools map[string]tools.Tool, toolsets map[string]tools.Toolset) (*httptest.Server, func()) {
|
||||
func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, toolsets map[string]tools.Toolset) (*httptest.Server, func()) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info")
|
||||
@@ -113,10 +153,26 @@ func setUpServer(t *testing.T, tools map[string]tools.Tool, toolsets map[string]
|
||||
t.Fatalf("unable to create custom metrics: %s", err)
|
||||
}
|
||||
|
||||
server := Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, tools: tools, toolsets: toolsets}
|
||||
r, err := apiRouter(&server)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to initialize router: %s", err)
|
||||
sseManager := &sseManager{
|
||||
mu: sync.RWMutex{},
|
||||
sseSessions: make(map[string]*sseSession),
|
||||
}
|
||||
|
||||
server := Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, sseManager: sseManager, tools: tools, toolsets: toolsets}
|
||||
var r chi.Router
|
||||
switch router {
|
||||
case "api":
|
||||
r, err = apiRouter(&server)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to initialize api router: %s", err)
|
||||
}
|
||||
case "mcp":
|
||||
r, err = mcpRouter(&server)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to initialize mcp router: %s", err)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("unknown router")
|
||||
}
|
||||
ts := httptest.NewServer(r)
|
||||
shutdown := func() {
|
||||
|
||||
@@ -29,6 +29,8 @@ const (
|
||||
toolsetGetCountName = "toolbox.server.toolset.get.count"
|
||||
toolGetCountName = "toolbox.server.tool.get.count"
|
||||
toolInvokeCountName = "toolbox.server.tool.invoke.count"
|
||||
mcpSseCountName = "toolbox.server.mcp.sse.count"
|
||||
mcpPostCountName = "toolbox.server.mcp.post.count"
|
||||
)
|
||||
|
||||
// Instrumentation defines the telemetry instrumentation for toolbox
|
||||
@@ -38,6 +40,8 @@ type Instrumentation struct {
|
||||
ToolsetGet metric.Int64Counter
|
||||
ToolGet metric.Int64Counter
|
||||
ToolInvoke metric.Int64Counter
|
||||
McpSse metric.Int64Counter
|
||||
McpPost metric.Int64Counter
|
||||
}
|
||||
|
||||
func CreateTelemetryInstrumentation(versionString string) (*Instrumentation, error) {
|
||||
@@ -74,12 +78,32 @@ func CreateTelemetryInstrumentation(versionString string) (*Instrumentation, err
|
||||
return nil, fmt.Errorf("unable to create %s metric: %w", toolInvokeCountName, err)
|
||||
}
|
||||
|
||||
mcpSse, err := meter.Int64Counter(
|
||||
mcpSseCountName,
|
||||
metric.WithDescription("Number of MCP SSE connection requests."),
|
||||
metric.WithUnit("{connection}"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create %s metric: %w", mcpSseCountName, err)
|
||||
}
|
||||
|
||||
mcpPost, err := meter.Int64Counter(
|
||||
mcpPostCountName,
|
||||
metric.WithDescription("Number of MCP Post calls."),
|
||||
metric.WithUnit("{call}"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create %s metric: %w", mcpPostCountName, err)
|
||||
}
|
||||
|
||||
instrumentation := &Instrumentation{
|
||||
Tracer: tracer,
|
||||
meter: meter,
|
||||
ToolsetGet: toolsetGet,
|
||||
ToolGet: toolGet,
|
||||
ToolInvoke: toolInvoke,
|
||||
McpSse: mcpSse,
|
||||
McpPost: mcpPost,
|
||||
}
|
||||
return instrumentation, nil
|
||||
}
|
||||
|
||||
363
internal/server/mcp.go
Normal file
363
internal/server/mcp.go
Normal file
@@ -0,0 +1,363 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
type sseSession struct {
|
||||
sessionId string
|
||||
writer http.ResponseWriter
|
||||
flusher http.Flusher
|
||||
done chan struct{}
|
||||
eventQueue chan string
|
||||
}
|
||||
|
||||
// sseManager manages and control access to sse sessions
|
||||
type sseManager struct {
|
||||
mu sync.RWMutex
|
||||
sseSessions map[string]*sseSession
|
||||
}
|
||||
|
||||
func (m *sseManager) get(id string) (*sseSession, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
session, ok := m.sseSessions[id]
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (m *sseManager) add(id string, session *sseSession) {
|
||||
m.mu.Lock()
|
||||
m.sseSessions[id] = session
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *sseManager) remove(id string) {
|
||||
m.mu.Lock()
|
||||
delete(m.sseSessions, id)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// mcpRouter creates a router that represents the routes under /mcp
|
||||
func mcpRouter(s *Server) (chi.Router, error) {
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Use(middleware.AllowContentType("application/json"))
|
||||
r.Use(middleware.StripSlashes)
|
||||
r.Use(render.SetContentType(render.ContentTypeJSON))
|
||||
|
||||
r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) })
|
||||
r.Post("/", func(w http.ResponseWriter, r *http.Request) { mcpHandler(s, w, r) })
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// sseHandler handles sse initialization and message.
|
||||
func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp/sse")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
sessionId := uuid.New().String()
|
||||
span.SetAttributes(attribute.String("session_id", sessionId))
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
}
|
||||
span.End()
|
||||
status := "success"
|
||||
if err != nil {
|
||||
status = "error"
|
||||
}
|
||||
s.instrumentation.McpSse.Add(
|
||||
r.Context(),
|
||||
1,
|
||||
metric.WithAttributes(attribute.String("toolbox.sse.sessionId", sessionId)),
|
||||
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
|
||||
)
|
||||
}()
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
err = fmt.Errorf("unable to retrieve flusher for sse")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
||||
}
|
||||
session := &sseSession{
|
||||
sessionId: sessionId,
|
||||
writer: w,
|
||||
flusher: flusher,
|
||||
done: make(chan struct{}),
|
||||
eventQueue: make(chan string, 100),
|
||||
}
|
||||
s.sseManager.add(sessionId, session)
|
||||
defer s.sseManager.remove(sessionId)
|
||||
|
||||
// send initial endpoint event
|
||||
messageEndpoint := fmt.Sprintf("http://%s/mcp?sessionId=%s", r.Host, sessionId)
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("sending endpoint event: %s", messageEndpoint))
|
||||
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", messageEndpoint)
|
||||
flusher.Flush()
|
||||
|
||||
clientClose := r.Context().Done()
|
||||
for {
|
||||
select {
|
||||
// Ensure that only a single responses are written at once
|
||||
case event := <-session.eventQueue:
|
||||
fmt.Fprint(w, event)
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("sending event: %s", event))
|
||||
flusher.Flush()
|
||||
// channel for client disconnection
|
||||
case <-clientClose:
|
||||
close(session.done)
|
||||
s.logger.DebugContext(ctx, "client disconnected")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mcpHandler handles all mcp messages.
|
||||
func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
var id, toolName, method string
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
}
|
||||
span.End()
|
||||
|
||||
status := "success"
|
||||
if err != nil {
|
||||
status = "error"
|
||||
}
|
||||
s.instrumentation.McpPost.Add(
|
||||
r.Context(),
|
||||
1,
|
||||
metric.WithAttributes(attribute.String("toolbox.sse.sessionId", id)),
|
||||
metric.WithAttributes(attribute.String("toolbox.name", toolName)),
|
||||
metric.WithAttributes(attribute.String("toolbox.method", method)),
|
||||
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
|
||||
)
|
||||
}()
|
||||
|
||||
// Read and returns a body from io.Reader
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
// Generate a new uuid if unable to decode
|
||||
id = uuid.New().String()
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
render.JSON(w, r, newJSONRPCError(id, mcp.PARSE_ERROR, err.Error(), nil))
|
||||
}
|
||||
|
||||
// Generic baseMessage could either be a JSONRPCNotification or JSONRPCRequest
|
||||
var baseMessage struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
Id mcp.RequestId `json:"id,omitempty"`
|
||||
}
|
||||
if err = decodeJSON(bytes.NewBuffer(body), &baseMessage); err != nil {
|
||||
// Generate a new uuid if unable to decode
|
||||
id := uuid.New().String()
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
render.JSON(w, r, newJSONRPCError(id, mcp.PARSE_ERROR, err.Error(), nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Check if method is present
|
||||
if baseMessage.Method == "" {
|
||||
err = fmt.Errorf("method not found")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
render.JSON(w, r, newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Check for JSON-RPC 2.0
|
||||
if baseMessage.Jsonrpc != mcp.JSONRPC_VERSION {
|
||||
err = fmt.Errorf("invalid json-rpc version")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
render.JSON(w, r, newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Check if message is a notification
|
||||
if baseMessage.Id == nil {
|
||||
id = ""
|
||||
var notification mcp.JSONRPCNotification
|
||||
if err = json.Unmarshal(body, ¬ification); err != nil {
|
||||
err = fmt.Errorf("invalid notification request: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
render.JSON(w, r, newJSONRPCError(baseMessage.Id, mcp.PARSE_ERROR, err.Error(), nil))
|
||||
}
|
||||
// Notifications do not expect a response
|
||||
// Toolbox doesn't do anything with notifications yet
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
id = fmt.Sprintf("%s", baseMessage.Id)
|
||||
method = baseMessage.Method
|
||||
|
||||
var res mcp.JSONRPCMessage
|
||||
switch baseMessage.Method {
|
||||
case "initialize":
|
||||
var req mcp.InitializeRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp initialize request: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
result := mcp.Initialize(s.version)
|
||||
res = mcp.JSONRPCResponse{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: baseMessage.Id,
|
||||
Result: result,
|
||||
}
|
||||
case "tools/list":
|
||||
var req mcp.ListToolsRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools list request: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
toolset, ok := s.toolsets[""]
|
||||
if !ok {
|
||||
err = fmt.Errorf("toolset does not exist")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
result := mcp.ToolsList(toolset)
|
||||
res = mcp.JSONRPCResponse{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: baseMessage.Id,
|
||||
Result: result,
|
||||
}
|
||||
case "tools/call":
|
||||
var req mcp.CallToolRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools call request: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
toolName = req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
tool, ok := s.tools[toolName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
|
||||
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
|
||||
aMarshal, err := json.Marshal(toolArgument)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to marshal tools argument: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
var data map[string]any
|
||||
if err = decodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
|
||||
err = fmt.Errorf("unable to decode tools argument: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
|
||||
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
|
||||
// Since MCP doesn't support auth, an empty map will be use every time.
|
||||
claimsFromAuth := make(map[string]map[string]any)
|
||||
|
||||
params, err := tool.ParseParams(data, claimsFromAuth)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
|
||||
result := mcp.ToolCall(tool, params)
|
||||
res = mcp.JSONRPCResponse{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: baseMessage.Id,
|
||||
Result: result,
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("invalid method %s", baseMessage.Method)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil)
|
||||
}
|
||||
|
||||
// retrieve sse session
|
||||
sessionId := r.URL.Query().Get("sessionId")
|
||||
session, ok := s.sseManager.get(sessionId)
|
||||
if !ok {
|
||||
s.logger.DebugContext(ctx, "sse session not available")
|
||||
} else {
|
||||
// queue sse event
|
||||
eventData, _ := json.Marshal(res)
|
||||
select {
|
||||
case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
|
||||
s.logger.DebugContext(ctx, "event queue successful")
|
||||
case <-session.done:
|
||||
s.logger.DebugContext(ctx, "session is close")
|
||||
default:
|
||||
s.logger.DebugContext(ctx, "unable to add to event queue")
|
||||
}
|
||||
}
|
||||
|
||||
// send HTTP response
|
||||
render.JSON(w, r, res)
|
||||
}
|
||||
|
||||
// newJSONRPCError is the response sent back when an error has been encountered in mcp.
|
||||
func newJSONRPCError(id mcp.RequestId, code int, message string, data any) mcp.JSONRPCError {
|
||||
return mcp.JSONRPCError{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Error: mcp.McpError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Data: data,
|
||||
},
|
||||
}
|
||||
}
|
||||
74
internal/server/mcp/method.go
Normal file
74
internal/server/mcp/method.go
Normal file
@@ -0,0 +1,74 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
func Initialize(version string) InitializeResult {
|
||||
toolsListChanged := false
|
||||
result := InitializeResult{
|
||||
ProtocolVersion: LATEST_PROTOCOL_VERSION,
|
||||
Capabilities: ServerCapabilities{
|
||||
Tools: &ListChanged{
|
||||
ListChanged: &toolsListChanged,
|
||||
},
|
||||
},
|
||||
ServerInfo: Implementation{
|
||||
Name: SERVER_NAME,
|
||||
Version: version,
|
||||
},
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ToolsList return a ListToolsResult
|
||||
func ToolsList(toolset tools.Toolset) ListToolsResult {
|
||||
mcpManifest := toolset.McpManifest
|
||||
|
||||
result := ListToolsResult{
|
||||
Tools: mcpManifest,
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ToolCall runs tool invocation and return a CallToolResult
|
||||
func ToolCall(tool tools.Tool, params tools.ParamValues) CallToolResult {
|
||||
res, err := tool.Invoke(params)
|
||||
if err != nil {
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
}
|
||||
return CallToolResult{Content: []TextContent{text}, IsError: true}
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
for _, d := range res {
|
||||
text := TextContent{Type: "text"}
|
||||
dM, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
text.Text = fmt.Sprintf("fail to marshal: %s, result: %s", err, d)
|
||||
} else {
|
||||
text.Text = string(dM)
|
||||
}
|
||||
content = append(content, text)
|
||||
}
|
||||
return CallToolResult{Content: content}
|
||||
}
|
||||
295
internal/server/mcp/types.go
Normal file
295
internal/server/mcp/types.go
Normal file
@@ -0,0 +1,295 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
// SERVER_NAME is the server name used in Implementation.
|
||||
const SERVER_NAME = "Toolbox"
|
||||
|
||||
// LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol.
|
||||
const LATEST_PROTOCOL_VERSION = "2024-11-05"
|
||||
|
||||
// JSONRPC_VERSION is the version of JSON-RPC used by MCP.
|
||||
const JSONRPC_VERSION = "2.0"
|
||||
|
||||
// Standard JSON-RPC error codes
|
||||
const (
|
||||
PARSE_ERROR = -32700
|
||||
INVALID_REQUEST = -32600
|
||||
METHOD_NOT_FOUND = -32601
|
||||
INVALID_PARAMS = -32602
|
||||
INTERNAL_ERROR = -32603
|
||||
)
|
||||
|
||||
// JSONRPCMessage represents either a JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, or JSONRPCError.
|
||||
type JSONRPCMessage interface{}
|
||||
|
||||
// ProgressToken is used to associate progress notifications with the original request.
|
||||
type ProgressToken interface{}
|
||||
|
||||
// Request represents a bidirectional message with method and parameters expecting a response.
|
||||
type Request struct {
|
||||
Method string `json:"method"`
|
||||
Params struct {
|
||||
Meta struct {
|
||||
// If specified, the caller is requesting out-of-band progress
|
||||
// notifications for this request (as represented by
|
||||
// notifications/progress). The value of this parameter is an
|
||||
// opaque token that will be attached to any subsequent
|
||||
// notifications. The receiver is not obligated to provide these
|
||||
// notifications.
|
||||
ProgressToken ProgressToken `json:"progressToken,omitempty"`
|
||||
} `json:"_meta,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// Notification is a one-way message requiring no response.
|
||||
type Notification struct {
|
||||
Method string `json:"method"`
|
||||
Params struct {
|
||||
Meta map[string]interface{} `json:"_meta,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// Result represents a response for the request query.
|
||||
type Result struct {
|
||||
// This result property is reserved by the protocol to allow clients and
|
||||
// servers to attach additional metadata to their responses.
|
||||
Meta map[string]interface{} `json:"_meta,omitempty"`
|
||||
}
|
||||
|
||||
// RequestId is a uniquely identifying ID for a request in JSON-RPC.
|
||||
// It can be any JSON-serializable value, typically a number or string.
|
||||
type RequestId interface{}
|
||||
|
||||
// JSONRPCRequest represents a request that expects a response.
|
||||
type JSONRPCRequest struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Id RequestId `json:"id"`
|
||||
Request
|
||||
Params any `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// JSONRPCNotification represents a notification which does not expect a response.
|
||||
type JSONRPCNotification struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Notification
|
||||
}
|
||||
|
||||
// JSONRPCResponse represents a successful (non-error) response to a request.
|
||||
type JSONRPCResponse struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Id RequestId `json:"id"`
|
||||
Result interface{} `json:"result"`
|
||||
}
|
||||
|
||||
// McpError represents the error content.
|
||||
type McpError struct {
|
||||
// The error type that occurred.
|
||||
Code int `json:"code"`
|
||||
// A short description of the error. The message SHOULD be limited
|
||||
// to a concise single sentence.
|
||||
Message string `json:"message"`
|
||||
// Additional information about the error. The value of this member
|
||||
// is defined by the sender (e.g. detailed error information, nested errors etc.).
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// JSONRPCError represents a non-successful (error) response to a request.
|
||||
type JSONRPCError struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Id RequestId `json:"id"`
|
||||
Error McpError `json:"error"`
|
||||
}
|
||||
|
||||
/* Empty result */
|
||||
|
||||
// EmptyResult represents a response that indicates success but carries no data.
|
||||
type EmptyResult Result
|
||||
|
||||
/* Initialization */
|
||||
|
||||
// Params to define MCP Client during initialize request.
|
||||
type InitializeParams struct {
|
||||
// The latest version of the Model Context Protocol that the client supports.
|
||||
// The client MAY decide to support older versions as well.
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities ClientCapabilities `json:"capabilities"`
|
||||
ClientInfo Implementation `json:"clientInfo"`
|
||||
}
|
||||
|
||||
// InitializeRequest is sent from the client to the server when it first
|
||||
// connects, asking it to begin initialization.
|
||||
type InitializeRequest struct {
|
||||
Request
|
||||
Params InitializeParams `json:"params"`
|
||||
}
|
||||
|
||||
// InitializeResult is sent after receiving an initialize request from the
|
||||
// client.
|
||||
type InitializeResult struct {
|
||||
Result
|
||||
// The version of the Model Context Protocol that the server wants to use.
|
||||
// This may not match the version that the client requested. If the client cannot
|
||||
// support this version, it MUST disconnect.
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities ServerCapabilities `json:"capabilities"`
|
||||
ServerInfo Implementation `json:"serverInfo"`
|
||||
// Instructions describing how to use the server and its features.
|
||||
//
|
||||
// This can be used by clients to improve the LLM's understanding of
|
||||
// available tools, resources, etc. It can be thought of like a "hint" to the model.
|
||||
// For example, this information MAY be added to the system prompt.
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
}
|
||||
|
||||
// InitializedNotification is sent from the client to the server after
|
||||
// initialization has finished.
|
||||
type InitializedNotification struct {
|
||||
Notification
|
||||
}
|
||||
|
||||
// ListChange represents whether the server supports notification for changes to the capabilities.
|
||||
type ListChanged struct {
|
||||
ListChanged *bool `json:"listChanged,omitempty"`
|
||||
}
|
||||
|
||||
// ClientCapabilities represents capabilities a client may support. Known
|
||||
// capabilities are defined here, in this schema, but this is not a closed set: any
|
||||
// client can define its own, additional capabilities.
|
||||
type ClientCapabilities struct {
|
||||
// Experimental, non-standard capabilities that the client supports.
|
||||
Experimental map[string]interface{} `json:"experimental,omitempty"`
|
||||
// Present if the client supports listing roots.
|
||||
Roots *ListChanged `json:"roots,omitempty"`
|
||||
// Present if the client supports sampling from an LLM.
|
||||
Sampling struct{} `json:"sampling,omitempty"`
|
||||
}
|
||||
|
||||
// ServerCapabilities represents capabilities that a server may support. Known
|
||||
// capabilities are defined here, in this schema, but this is not a closed set: any
|
||||
// server can define its own, additional capabilities.
|
||||
type ServerCapabilities struct {
|
||||
Tools *ListChanged `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// Implementation describes the name and version of an MCP implementation.
|
||||
type Implementation struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
/* Pagination */
|
||||
|
||||
// Cursor is an opaque token used to represent a cursor for pagination.
|
||||
type Cursor string
|
||||
|
||||
type PaginatedRequest struct {
|
||||
Request
|
||||
Params struct {
|
||||
// An opaque token representing the current pagination position.
|
||||
// If provided, the server should return results starting after this cursor.
|
||||
Cursor Cursor `json:"cursor,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type PaginatedResult struct {
|
||||
Result
|
||||
// An opaque token representing the pagination position after the last returned result.
|
||||
// If present, there may be more results available.
|
||||
NextCursor Cursor `json:"nextCursor,omitempty"`
|
||||
}
|
||||
|
||||
/* Tools */
|
||||
|
||||
// Sent from the client to request a list of tools the server has.
|
||||
type ListToolsRequest struct {
|
||||
PaginatedRequest
|
||||
}
|
||||
|
||||
// The server's response to a tools/list request from the client.
|
||||
type ListToolsResult struct {
|
||||
PaginatedResult
|
||||
Tools []tools.McpManifest `json:"tools"`
|
||||
}
|
||||
|
||||
// Used by the client to invoke a tool provided by the server.
|
||||
type CallToolRequest struct {
|
||||
Request
|
||||
Params struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// The sender or recipient of messages and data in a conversation.
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleUser Role = "user"
|
||||
RoleAssistant Role = "assistant"
|
||||
)
|
||||
|
||||
// Base for objects that include optional annotations for the client.
|
||||
// The client can use annotations to inform how objects are used or displayed
|
||||
type Annotated struct {
|
||||
Annotations *struct {
|
||||
// Describes who the intended customer of this object or data is.
|
||||
// It can include multiple entries to indicate content useful for multiple
|
||||
// audiences (e.g., `["user", "assistant"]`).
|
||||
Audience []Role `json:"audience,omitempty"`
|
||||
// Describes how important this data is for operating the server.
|
||||
//
|
||||
// A value of 1 means "most important," and indicates that the data is
|
||||
// effectively required, while 0 means "least important," and indicates that
|
||||
// the data is entirely optional.
|
||||
//
|
||||
// @TJS-type number
|
||||
// @minimum 0
|
||||
// @maximum 1
|
||||
Priority float64 `json:"priority,omitempty"`
|
||||
} `json:"annotations,omitempty"`
|
||||
}
|
||||
|
||||
// TextContent represents text provided to or from an LLM.
|
||||
type TextContent struct {
|
||||
Annotated
|
||||
Type string `json:"type"`
|
||||
// The text content of the message.
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// The server's response to a tool call.
|
||||
//
|
||||
// Any errors that originate from the tool SHOULD be reported inside the result
|
||||
// object, with `isError` set to true, _not_ as an MCP protocol-level error
|
||||
// response. Otherwise, the LLM would not be able to see that an error occurred
|
||||
// and self-correct.
|
||||
//
|
||||
// However, any errors in _finding_ the tool, an error indicating that the
|
||||
// server does not support tool calls, or any other exceptional conditions,
|
||||
// should be reported as an MCP error response.
|
||||
type CallToolResult struct {
|
||||
Result
|
||||
// Could be either a TextContent, ImageContent, or EmbeddedResources
|
||||
// For Toolbox, we will only be sending TextContent
|
||||
Content []TextContent `json:"content"`
|
||||
// Whether the tool call ended in an error.
|
||||
// If not set, this is assumed to be false (the call was successful).
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
258
internal/server/mcp_test.go
Normal file
258
internal/server/mcp_test.go
Normal file
@@ -0,0 +1,258 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp"
|
||||
)
|
||||
|
||||
const jsonrpcVersion = "2.0"
|
||||
const protocolVersion = "2024-11-05"
|
||||
const serverName = "Toolbox"
|
||||
|
||||
var tool1InputSchema = map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
"required": []any{},
|
||||
}
|
||||
|
||||
var tool2InputSchema = map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"param1": map[string]any{"type": "integer", "description": "This is the first parameter."},
|
||||
"param2": map[string]any{"type": "integer", "description": "This is the second parameter."},
|
||||
},
|
||||
"required": []any{"param1", "param2"},
|
||||
}
|
||||
|
||||
var tool3InputSchema = map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"my_array": map[string]any{
|
||||
"type": "array",
|
||||
"description": "this param is an array of strings",
|
||||
"items": map[string]any{"type": "string", "description": "string item"},
|
||||
},
|
||||
},
|
||||
"required": []any{"my_array"},
|
||||
}
|
||||
|
||||
func TestMcpEndpoint(t *testing.T) {
|
||||
mockTools := []MockTool{tool1, tool2, tool3}
|
||||
toolsMap, toolsets := setUpResources(t, mockTools)
|
||||
ts, shutdown := setUpServer(t, "mcp", toolsMap, toolsets)
|
||||
defer shutdown()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
isErr bool
|
||||
body mcp.JSONRPCRequest
|
||||
want map[string]any
|
||||
}{
|
||||
{
|
||||
name: "initialize",
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "mcp-initialize",
|
||||
Request: mcp.Request{
|
||||
Method: "initialize",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "mcp-initialize",
|
||||
"result": map[string]any{
|
||||
"protocolVersion": protocolVersion,
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{"listChanged": false},
|
||||
},
|
||||
"serverInfo": map[string]any{"name": serverName, "version": fakeVersionString},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "basic notification",
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Request: mcp.Request{
|
||||
Method: "notification",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tools/list",
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "tools-list",
|
||||
Request: mcp.Request{
|
||||
Method: "tools/list",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "tools-list",
|
||||
"result": map[string]any{
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"name": "no_params",
|
||||
"inputSchema": tool1InputSchema,
|
||||
},
|
||||
map[string]any{
|
||||
"name": "some_params",
|
||||
"inputSchema": tool2InputSchema,
|
||||
},
|
||||
map[string]any{
|
||||
"name": "array_param",
|
||||
"description": "some description",
|
||||
"inputSchema": tool3InputSchema,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing method",
|
||||
isErr: true,
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "missing-method",
|
||||
Request: mcp.Request{},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "missing-method",
|
||||
"error": map[string]any{
|
||||
"code": -32601.0,
|
||||
"message": "method not found",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid method",
|
||||
isErr: true,
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "invalid-method",
|
||||
Request: mcp.Request{
|
||||
Method: "foo",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "invalid-method",
|
||||
"error": map[string]any{
|
||||
"code": -32601.0,
|
||||
"message": "invalid method foo",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid jsonrpc version",
|
||||
isErr: true,
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: "1.0",
|
||||
Id: "invalid-jsonrpc-version",
|
||||
Request: mcp.Request{
|
||||
Method: "foo",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "invalid-jsonrpc-version",
|
||||
"error": map[string]any{
|
||||
"code": -32600.0,
|
||||
"message": "invalid json-rpc version",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
reqMarshal, err := json.Marshal(tc.body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during marshaling of body")
|
||||
}
|
||||
|
||||
resp, body, err := runRequest(ts, http.MethodPost, "/", bytes.NewBuffer(reqMarshal))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
|
||||
// Notifications don't expect a response.
|
||||
if tc.want != nil {
|
||||
if contentType := resp.Header.Get("Content-type"); contentType != "application/json" {
|
||||
t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("unexpected error unmarshalling body: %s", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("unexpected response: got %+v, want %+v", got, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSseEndpoint(t *testing.T) {
|
||||
ts, shutdown := setUpServer(t, "mcp", nil, nil)
|
||||
defer shutdown()
|
||||
|
||||
contentType := "text/event-stream"
|
||||
cacheControl := "no-cache"
|
||||
connection := "keep-alive"
|
||||
accessControlAllowOrigin := "*"
|
||||
wantEvent := "event: endpoint"
|
||||
|
||||
t.Run("test sse endpoint", func(t *testing.T) {
|
||||
resp, err := http.Get(ts.URL + "/sse")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if gotContentType := resp.Header.Get("Content-type"); gotContentType != contentType {
|
||||
t.Fatalf("unexpected content-type header: want %s, got %s", contentType, gotContentType)
|
||||
}
|
||||
if gotCacheControl := resp.Header.Get("Cache-Control"); gotCacheControl != cacheControl {
|
||||
t.Fatalf("unexpected cache-control header: want %s, got %s", cacheControl, gotCacheControl)
|
||||
}
|
||||
if gotConnection := resp.Header.Get("Connection"); gotConnection != connection {
|
||||
t.Fatalf("unexpected content-type header: want %s, got %s", connection, gotConnection)
|
||||
}
|
||||
if gotAccessControlAllowOrigin := resp.Header.Get("Access-Control-Allow-Origin"); gotAccessControlAllowOrigin != accessControlAllowOrigin {
|
||||
t.Fatalf("unexpected cache-control header: want %s, got %s", accessControlAllowOrigin, gotAccessControlAllowOrigin)
|
||||
}
|
||||
|
||||
buffer := make([]byte, 1024)
|
||||
n, err := resp.Body.Read(buffer)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to read response: %s", err)
|
||||
}
|
||||
endpointEvent := string(buffer[:n])
|
||||
if !strings.Contains(endpointEvent, wantEvent) {
|
||||
t.Fatalf("unexpected event: got %s", endpointEvent)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -42,6 +43,7 @@ type Server struct {
|
||||
root chi.Router
|
||||
logger log.Logger
|
||||
instrumentation *Instrumentation
|
||||
sseManager *sseManager
|
||||
|
||||
sources map[string]sources.Source
|
||||
authServices map[string]auth.AuthService
|
||||
@@ -203,12 +205,18 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
|
||||
addr := net.JoinHostPort(cfg.Address, strconv.Itoa(cfg.Port))
|
||||
srv := &http.Server{Addr: addr, Handler: r}
|
||||
|
||||
sseManager := &sseManager{
|
||||
mu: sync.RWMutex{},
|
||||
sseSessions: make(map[string]*sseSession),
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
version: cfg.Version,
|
||||
srv: srv,
|
||||
root: r,
|
||||
logger: l,
|
||||
instrumentation: instrumentation,
|
||||
sseManager: sseManager,
|
||||
|
||||
sources: sourcesMap,
|
||||
authServices: authServicesMap,
|
||||
@@ -221,6 +229,11 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
|
||||
return nil, err
|
||||
}
|
||||
r.Mount("/api", apiR)
|
||||
mcpR, err := mcpRouter(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.Mount("/mcp", mcpR)
|
||||
// default endpoint for validating server is running
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("🧰 Hello, World! 🧰"))
|
||||
|
||||
@@ -65,6 +65,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: cfg.Parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
@@ -75,6 +81,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
IsQuery: cfg.IsQuery,
|
||||
Timeout: cfg.Timeout,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -92,6 +99,7 @@ type Tool struct {
|
||||
Timeout string
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
|
||||
@@ -127,6 +135,10 @@ func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
// Create a slice for all parameters
|
||||
allParameters := slices.Concat(cfg.BodyParams, cfg.HeaderParams, cfg.QueryParams)
|
||||
|
||||
// Create parameter manifest
|
||||
// Create parameter MCP manifest
|
||||
paramManifest := slices.Concat(
|
||||
cfg.QueryParams.Manifest(),
|
||||
cfg.BodyParams.Manifest(),
|
||||
@@ -101,6 +101,36 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
queryMcpManifest := cfg.QueryParams.McpManifest()
|
||||
bodyMcpManifest := cfg.BodyParams.McpManifest()
|
||||
headerMcpManifest := cfg.HeaderParams.McpManifest()
|
||||
|
||||
// Concatenate parameters for MCP `required` field
|
||||
concatRequiredManifest := slices.Concat(
|
||||
queryMcpManifest.Required,
|
||||
bodyMcpManifest.Required,
|
||||
headerMcpManifest.Required,
|
||||
)
|
||||
|
||||
// Concatenate parameters for MCP `properties` field
|
||||
concatPropertiesManifest := make(map[string]tools.ParameterMcpManifest)
|
||||
for name, p := range queryMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
for name, p := range bodyMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
for name, p := range headerMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
|
||||
// Create a new McpToolsSchema with all parameters
|
||||
paramMcpManifest := tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: concatPropertiesManifest,
|
||||
Required: concatRequiredManifest,
|
||||
}
|
||||
|
||||
// Verify there are no duplicate parameter names
|
||||
seenNames := make(map[string]bool)
|
||||
for _, param := range paramManifest {
|
||||
@@ -110,6 +140,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
seenNames[param.Name] = true
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
@@ -125,6 +161,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Client: s.Client,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -146,8 +183,9 @@ type Tool struct {
|
||||
HeaderParams tools.Parameters `yaml:"headerParams"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// helper function to convert a parameter to JSON formatted string.
|
||||
@@ -271,6 +309,10 @@ func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
@@ -68,6 +68,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: cfg.Parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
@@ -77,6 +83,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Db: s.MSSQLDB(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -90,9 +97,10 @@ type Tool struct {
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Db *sql.DB
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
Db *sql.DB
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
|
||||
@@ -157,6 +165,10 @@ func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
@@ -67,6 +67,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: cfg.Parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
@@ -76,6 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.MySQLPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -89,9 +96,10 @@ type Tool struct {
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Pool *sql.DB
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
Pool *sql.DB
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
|
||||
@@ -152,6 +160,10 @@ func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
@@ -66,15 +66,22 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: cfg.Parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: ToolKind,
|
||||
Parameters: cfg.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
Driver: s.Neo4jDriver(),
|
||||
Database: s.Neo4jDatabase(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
Name: cfg.Name,
|
||||
Kind: ToolKind,
|
||||
Parameters: cfg.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
Driver: s.Neo4jDriver(),
|
||||
Database: s.Neo4jDatabase(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -88,10 +95,11 @@ type Tool struct {
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
|
||||
Driver neo4j.DriverWithContext
|
||||
Database string
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
Driver neo4j.DriverWithContext
|
||||
Database string
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
|
||||
@@ -127,6 +135,10 @@ func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
@@ -154,6 +154,14 @@ type Parameter interface {
|
||||
GetAuthServices() []ParamAuthService
|
||||
Parse(any) (any, error)
|
||||
Manifest() ParameterManifest
|
||||
McpManifest() ParameterMcpManifest
|
||||
}
|
||||
|
||||
// McpToolsSchema is the representation of input schema for McpManifest.
|
||||
type McpToolsSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]ParameterMcpManifest `json:"properties"`
|
||||
Required []string `json:"required"`
|
||||
}
|
||||
|
||||
// Parameters is a type used to allow unmarshal a list of parameters
|
||||
@@ -265,6 +273,24 @@ func (ps Parameters) Manifest() []ParameterManifest {
|
||||
return rtn
|
||||
}
|
||||
|
||||
func (ps Parameters) McpManifest() McpToolsSchema {
|
||||
properties := make(map[string]ParameterMcpManifest)
|
||||
required := make([]string, 0)
|
||||
|
||||
for _, p := range ps {
|
||||
name := p.GetName()
|
||||
properties[name] = p.McpManifest()
|
||||
// all parameters are added to the required field
|
||||
required = append(required, name)
|
||||
}
|
||||
|
||||
return McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: properties,
|
||||
Required: required,
|
||||
}
|
||||
}
|
||||
|
||||
// ParameterManifest represents parameters when served as part of a ToolManifest.
|
||||
type ParameterManifest struct {
|
||||
Name string `json:"name"`
|
||||
@@ -274,6 +300,13 @@ type ParameterManifest struct {
|
||||
Items *ParameterManifest `json:"items,omitempty"`
|
||||
}
|
||||
|
||||
// ParameterMcpManifest represents properties when served as part of a ToolMcpManifest.
|
||||
type ParameterMcpManifest struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Items *ParameterMcpManifest `json:"items,omitempty"`
|
||||
}
|
||||
|
||||
// CommonParameter are default fields that are emebdding in most Parameter implementations. Embedding this stuct will give the object Name() and Type() functions.
|
||||
type CommonParameter struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
@@ -308,6 +341,14 @@ func (p *CommonParameter) Manifest() ParameterManifest {
|
||||
}
|
||||
}
|
||||
|
||||
// McpManifest returns the MCP manifest for the Parameter.
|
||||
func (p *CommonParameter) McpManifest() ParameterMcpManifest {
|
||||
return ParameterMcpManifest{
|
||||
Type: p.Type,
|
||||
Description: p.Desc,
|
||||
}
|
||||
}
|
||||
|
||||
// ParseTypeError is a custom error for incorrectly typed Parameters.
|
||||
type ParseTypeError struct {
|
||||
Name string
|
||||
@@ -611,3 +652,18 @@ func (p *ArrayParameter) Manifest() ParameterManifest {
|
||||
Items: &items,
|
||||
}
|
||||
}
|
||||
|
||||
// McpManifest returns the MCP manifest for the ArrayParameter.
|
||||
func (p *ArrayParameter) McpManifest() ParameterMcpManifest {
|
||||
// only list ParamAuthService names (without fields) in manifest
|
||||
authNames := make([]string, len(p.AuthServices))
|
||||
for i, a := range p.AuthServices {
|
||||
authNames[i] = a.Name
|
||||
}
|
||||
items := p.Items.McpManifest()
|
||||
return ParameterMcpManifest{
|
||||
Type: p.Type,
|
||||
Description: p.Desc,
|
||||
Items: &items,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -847,6 +847,52 @@ func TestParamManifest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParamMcpManifest(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
in tools.Parameter
|
||||
want tools.ParameterMcpManifest
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
in: tools.NewStringParameter("foo-string", "bar"),
|
||||
want: tools.ParameterMcpManifest{Type: "string", Description: "bar"},
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
in: tools.NewIntParameter("foo-int", "bar"),
|
||||
want: tools.ParameterMcpManifest{Type: "integer", Description: "bar"},
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
in: tools.NewFloatParameter("foo-float", "bar"),
|
||||
want: tools.ParameterMcpManifest{Type: "float", Description: "bar"},
|
||||
},
|
||||
{
|
||||
name: "boolean",
|
||||
in: tools.NewBooleanParameter("foo-bool", "bar"),
|
||||
want: tools.ParameterMcpManifest{Type: "boolean", Description: "bar"},
|
||||
},
|
||||
{
|
||||
name: "array",
|
||||
in: tools.NewArrayParameter("foo-array", "bar", tools.NewStringParameter("foo-string", "bar")),
|
||||
want: tools.ParameterMcpManifest{
|
||||
Type: "array",
|
||||
Description: "bar",
|
||||
Items: &tools.ParameterMcpManifest{Type: "string", Description: "bar"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tc.in.McpManifest()
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("unexpected manifest: got %+v, want %+v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParametersUnmarshal(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
|
||||
@@ -69,6 +69,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: cfg.Parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
@@ -78,6 +84,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -91,9 +98,10 @@ type Tool struct {
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Pool *pgxpool.Pool
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
Pool *pgxpool.Pool
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
|
||||
@@ -129,6 +137,10 @@ func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
@@ -68,6 +68,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: cfg.Parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
@@ -78,6 +84,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Client: s.SpannerClient(),
|
||||
dialect: s.DatabaseDialect(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -91,10 +98,11 @@ type Tool struct {
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *spanner.Client
|
||||
dialect string
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
Client *spanner.Client
|
||||
dialect string
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func getMapParams(params tools.ParamValues, dialect string) (map[string]interface{}, error) {
|
||||
@@ -157,6 +165,10 @@ func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ type Tool interface {
|
||||
Invoke(ParamValues) ([]any, error)
|
||||
ParseParams(map[string]any, map[string]map[string]any) (ParamValues, error)
|
||||
Manifest() Manifest
|
||||
McpManifest() McpManifest
|
||||
Authorized([]string) bool
|
||||
}
|
||||
|
||||
@@ -38,6 +39,16 @@ type Manifest struct {
|
||||
Parameters []ParameterManifest `json:"parameters"`
|
||||
}
|
||||
|
||||
// Definition for a tool the MCP client can call.
|
||||
type McpManifest struct {
|
||||
// The name of the tool.
|
||||
Name string `json:"name"`
|
||||
// A human-readable description of the tool.
|
||||
Description string `json:"description,omitempty"`
|
||||
// A JSON Schema object defining the expected parameters for the tool.
|
||||
InputSchema McpToolsSchema `json:"inputSchema,omitempty"`
|
||||
}
|
||||
|
||||
// Helper function that returns if a tool invocation request is authorized
|
||||
func IsAuthorized(authRequiredSources []string, verifiedAuthServices []string) bool {
|
||||
if len(authRequiredSources) == 0 {
|
||||
|
||||
@@ -24,9 +24,10 @@ type ToolsetConfig struct {
|
||||
}
|
||||
|
||||
type Toolset struct {
|
||||
Name string `yaml:"name"`
|
||||
Tools []*Tool `yaml:",inline"`
|
||||
Manifest ToolsetManifest `yaml:",inline"`
|
||||
Name string `yaml:"name"`
|
||||
Tools []*Tool `yaml:",inline"`
|
||||
Manifest ToolsetManifest `yaml:",inline"`
|
||||
McpManifest []McpManifest `yaml:",inline"`
|
||||
}
|
||||
|
||||
type ToolsetManifest struct {
|
||||
@@ -54,6 +55,7 @@ func (t ToolsetConfig) Initialize(serverVersion string, toolsMap map[string]Tool
|
||||
}
|
||||
toolset.Tools = append(toolset.Tools, &tool)
|
||||
toolset.Manifest.ToolsManifest[toolName] = tool.Manifest()
|
||||
toolset.McpManifest = append(toolset.McpManifest, tool.McpManifest())
|
||||
}
|
||||
|
||||
return toolset, nil
|
||||
|
||||
@@ -163,7 +163,9 @@ func TestAlloyDBToolEndpoints(t *testing.T) {
|
||||
RunToolGetTest(t)
|
||||
|
||||
select_1_want := "[{\"?column?\":1}]"
|
||||
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ERROR: syntax error at or near \"SELEC\" (SQLSTATE 42601)"}],"isError":true}}`
|
||||
RunToolInvokeTest(t, select_1_want)
|
||||
RunMCPToolCallMethod(t, fail_invocation_want)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -150,7 +150,9 @@ func TestCloudSQLMssqlToolEndpoints(t *testing.T) {
|
||||
RunToolGetTest(t)
|
||||
|
||||
select_1_want := "[{\"\":1}]"
|
||||
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: mssql: Could not find stored procedure 'SELEC'."}],"isError":true}}`
|
||||
RunToolInvokeTest(t, select_1_want)
|
||||
RunMCPToolCallMethod(t, fail_invocation_want)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -145,7 +145,9 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
|
||||
RunToolGetTest(t)
|
||||
|
||||
select_1_want := "[{\"1\":1}]"
|
||||
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}`
|
||||
RunToolInvokeTest(t, select_1_want)
|
||||
RunMCPToolCallMethod(t, fail_invocation_want)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -149,7 +149,9 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
RunToolGetTest(t)
|
||||
|
||||
select_1_want := "[{\"?column?\":1}]"
|
||||
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ERROR: syntax error at or near \"SELEC\" (SQLSTATE 42601)"}],"isError":true}}`
|
||||
RunToolInvokeTest(t, select_1_want)
|
||||
RunMCPToolCallMethod(t, fail_invocation_want)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -93,6 +93,12 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, param_tool_statement,
|
||||
"my-google-auth",
|
||||
},
|
||||
},
|
||||
"my-fail-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test statement with incorrect syntax.",
|
||||
"statement": "SELEC 1;",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -122,5 +122,7 @@ func TestMsSQLToolEndpoints(t *testing.T) {
|
||||
RunToolGetTest(t)
|
||||
|
||||
select_1_want := "[{\"\":1}]"
|
||||
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: mssql: Could not find stored procedure 'SELEC'."}],"isError":true}}`
|
||||
RunToolInvokeTest(t, select_1_want)
|
||||
RunMCPToolCallMethod(t, fail_invocation_want)
|
||||
}
|
||||
|
||||
@@ -121,5 +121,7 @@ func TestMySQLToolEndpoints(t *testing.T) {
|
||||
RunToolGetTest(t)
|
||||
|
||||
select_1_want := "[{\"1\":1}]"
|
||||
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}`
|
||||
RunToolInvokeTest(t, select_1_want)
|
||||
RunMCPToolCallMethod(t, fail_invocation_want)
|
||||
}
|
||||
|
||||
@@ -121,5 +121,7 @@ func TestPostgres(t *testing.T) {
|
||||
RunToolGetTest(t)
|
||||
|
||||
select_1_want := "[{\"?column?\":1}]"
|
||||
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ERROR: syntax error at or near \"SELEC\" (SQLSTATE 42601)"}],"isError":true}}`
|
||||
RunToolInvokeTest(t, select_1_want)
|
||||
RunMCPToolCallMethod(t, fail_invocation_want)
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
@@ -304,3 +305,136 @@ func RunToolInvokeTest(t *testing.T, select_1_want string) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RunMCPToolCallMethod runs the tool/call for mcp endpoint
|
||||
func RunMCPToolCallMethod(t *testing.T, fail_invocation_want string) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody mcp.JSONRPCRequest
|
||||
requestHeader map[string]string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "MCP Invoke my-param-tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "my-param-tool",
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": "my-param-tool",
|
||||
"arguments": map[string]any{
|
||||
"id": int(3),
|
||||
"name": "Alice",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: `{"jsonrpc":"2.0","id":"my-param-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke invalid tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invalid-tool",
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": "foo",
|
||||
"arguments": map[string]any{},
|
||||
},
|
||||
},
|
||||
want: `{"jsonrpc":"2.0","id":"invalid-tool","error":{"code":-32602,"message":"invalid tool name: tool with name \"foo\" does not exist"}}`,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invoke-without-parameter",
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": "my-tool",
|
||||
"arguments": map[string]any{},
|
||||
},
|
||||
},
|
||||
want: `{"jsonrpc":"2.0","id":"invoke-without-parameter","error":{"code":-32602,"message":"invalid tool name: tool with name \"my-tool\" does not exist"}}`,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-tool with insufficient parameters",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invoke-insufficient-parameter",
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": "my-tool",
|
||||
"arguments": map[string]any{"id": 1},
|
||||
},
|
||||
},
|
||||
want: `{"jsonrpc":"2.0","id":"invoke-insufficient-parameter","error":{"code":-32602,"message":"invalid tool name: tool with name \"my-tool\" does not exist"}}`,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-fail-tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invoke-fail-tool",
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": "my-fail-tool",
|
||||
"arguments": map[string]any{"id": 1},
|
||||
},
|
||||
},
|
||||
want: fail_invocation_want,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
reqMarshal, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during marshaling of request body")
|
||||
}
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, bytes.NewBuffer(reqMarshal))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to read request body: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
got := string(bytes.TrimSpace(respBody))
|
||||
|
||||
if got != tc.want {
|
||||
fmt.Printf("res is %s\n\n", got)
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user