feat: adding support for Model Context Protocol (MCP). (#396)

Adding Toolbox support for MCP. Toolbox can now be run as an MCP server.

Fixes #312.

---------

Co-authored-by: Jack Wotherspoon <jackwoth@google.com>
Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
Co-authored-by: Averi Kitsch <akitsch@google.com>
This commit is contained in:
Yuan
2025-04-04 11:24:34 -07:00
committed by GitHub
parent 55dff38db2
commit a7d1d4eb2a
36 changed files with 1833 additions and 61 deletions

View File

@@ -80,9 +80,5 @@ with clients in their preferred language. As Gen AI matures, we want developers
## Is Toolbox compatible with Model Context Protocol (MCP)?
Toolbox currently uses it's own custom protocol for server-client communication.
[Anthropic's Model Context Protocol (MCP)](https://modelcontextprotocol.io/)
launched towards the end of Toolbox's development, and is currently missing
functionality to support some of our features. We're currently exploring how
best to bring Toolbox's functionality to the wider MCP ecosystem.
Yes! Toolbox is compatible with [Anthropic's Model Context Protocol (MCP)](https://modelcontextprotocol.io/). Please checkout [Connect via MCP](../how-to/connect_via_mcp.md) on how to
connect to Toolbox with an MCP client.

View File

@@ -81,18 +81,22 @@ A metric is a measurement of a service captured at runtime. The collected data
can be used to provide important insights into the service. Toolbox provides the
following custom metrics:
| **Metric Name** | **Description** |
|------------------------------------|-------------------------------------------------------|
| `toolbox.server.toolset.get.count` | Counts the number of toolset manifest requests served |
| `toolbox.server.tool.get.count` | Counts the number of tool manifest requests served |
| `toolbox.server.tool.get.invoke` | Counts the number of tool invocation requests served |
| **Metric Name** | **Description** |
|------------------------------------|---------------------------------------------------------|
| `toolbox.server.toolset.get.count` | Counts the number of toolset manifest requests served |
| `toolbox.server.tool.get.count` | Counts the number of tool manifest requests served |
| `toolbox.server.tool.get.invoke` | Counts the number of tool invocation requests served |
| `toolbox.server.mcp.sse.count` | Counts the number of mcp sse connection requests served |
| `toolbox.server.mcp.post.count` | Counts the number of mcp post requests served |
All custom metrics have the following attributes/labels:
| **Metric Attributes** | **Description** |
|-----------------------|-----------------------------------------------------------|
| `toolbox.name` | Name of the toolset or tool, if applicable. |
| `toolbox.status` | Operation status code, for example: `success`, `failure`. |
| **Metric Attributes** | **Description** |
|----------------------------|-----------------------------------------------------------|
| `toolbox.name` | Name of the toolset or tool, if applicable. |
| `toolbox.operation.status` | Operation status code, for example: `success`, `failure`. |
| `toolbox.sse.sessionId` | Session id for sse connection, if applicable. |
| `toolbox.method` | Method of JSON-RPC request, if applicable. |
### Traces

View File

@@ -1,7 +1,7 @@
---
title: "Configuration"
type: docs
weight: 3
weight: 4
description: How to configure Toolbox's tools.yaml file.
---

View File

@@ -1,5 +1,5 @@
---
title: "Quickstart"
title: "Quickstart (Local)"
type: docs
weight: 2
description: >

View File

@@ -0,0 +1,228 @@
---
title: "Quickstart (MCP)"
type: docs
weight: 3
description: >
How to get started running Toolbox locally with MCP Inspector.
---
## Overview
[Model Context Protocol](https://modelcontextprotocol.io) is an open protocol
that standardizes how applications provide context to LLMs. Check out this page
on how to [connect to Toolbox via MCP](../../how-to/connect_via_mcp.md).
## Step 1: Set up your database
In this section, we will create a database, insert some data that needs to be
access by our agent, and create a database user for Toolbox to connect with.
1. Connect to postgres using the `psql` command:
```bash
psql -h 127.0.0.1 -U postgres
```
Here, `postgres` denotes the default postgres superuser.
1. Create a new database and a new user:
{{< notice tip >}}
For a real application, it's best to follow the principle of least permission
and only grant the privileges your application needs.
{{< /notice >}}
```sql
CREATE USER toolbox_user WITH PASSWORD 'my-password';
CREATE DATABASE toolbox_db;
GRANT ALL PRIVILEGES ON DATABASE toolbox_db TO toolbox_user;
ALTER DATABASE toolbox_db OWNER TO toolbox_user;
```
1. End the database session:
```bash
\q
```
1. Connect to your database with your new user:
```bash
psql -h 127.0.0.1 -U toolbox_user -d toolbox_db
```
1. Create a table using the following command:
```sql
CREATE TABLE hotels(
id INTEGER NOT NULL PRIMARY KEY,
name VARCHAR NOT NULL,
location VARCHAR NOT NULL,
price_tier VARCHAR NOT NULL,
checkin_date DATE NOT NULL,
checkout_date DATE NOT NULL,
booked BIT NOT NULL
);
```
1. Insert data into the table.
```sql
INSERT INTO hotels(id, name, location, price_tier, checkin_date, checkout_date, booked)
VALUES
(1, 'Hilton Basel', 'Basel', 'Luxury', '2024-04-22', '2024-04-20', B'0'),
(2, 'Marriott Zurich', 'Zurich', 'Upscale', '2024-04-14', '2024-04-21', B'0'),
(3, 'Hyatt Regency Basel', 'Basel', 'Upper Upscale', '2024-04-02', '2024-04-20', B'0'),
(4, 'Radisson Blu Lucerne', 'Lucerne', 'Midscale', '2024-04-24', '2024-04-05', B'0'),
(5, 'Best Western Bern', 'Bern', 'Upper Midscale', '2024-04-23', '2024-04-01', B'0'),
(6, 'InterContinental Geneva', 'Geneva', 'Luxury', '2024-04-23', '2024-04-28', B'0'),
(7, 'Sheraton Zurich', 'Zurich', 'Upper Upscale', '2024-04-27', '2024-04-02', B'0'),
(8, 'Holiday Inn Basel', 'Basel', 'Upper Midscale', '2024-04-24', '2024-04-09', B'0'),
(9, 'Courtyard Zurich', 'Zurich', 'Upscale', '2024-04-03', '2024-04-13', B'0'),
(10, 'Comfort Inn Bern', 'Bern', 'Midscale', '2024-04-04', '2024-04-16', B'0');
```
1. End the database session:
```bash
\q
```
## Step 2: Install and configure Toolbox
In this section, we will download Toolbox, configure our tools in a
`tools.yaml`, and then run the Toolbox server.
1. Download the latest version of Toolbox as a binary:
{{< notice tip >}}
Select the
[correct binary](https://github.com/googleapis/genai-toolbox/releases)
corresponding to your OS and CPU architecture.
{{< /notice >}}
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.2.1/$OS/toolbox
```
<!-- {x-release-please-end} -->
1. Make the binary executable:
```bash
chmod +x toolbox
```
1. Write the following into a `tools.yaml` file. Be sure to update any fields
such as `user`, `password`, or `database` that you may have customized in the
previous step.
```yaml
sources:
my-pg-source:
kind: postgres
host: 127.0.0.1
port: 5432
database: toolbox_db
user: toolbox_user
password: my-password
tools:
search-hotels-by-name:
kind: postgres-sql
source: my-pg-source
description: Search for hotels based on name.
parameters:
- name: name
type: string
description: The name of the hotel.
statement: SELECT * FROM hotels WHERE name ILIKE '%' || $1 || '%';
search-hotels-by-location:
kind: postgres-sql
source: my-pg-source
description: Search for hotels based on location.
parameters:
- name: location
type: string
description: The location of the hotel.
statement: SELECT * FROM hotels WHERE location ILIKE '%' || $1 || '%';
book-hotel:
kind: postgres-sql
source: my-pg-source
description: >-
Book a hotel by its ID. If the hotel is successfully booked, returns a NULL, raises an error if not.
parameters:
- name: hotel_id
type: string
description: The ID of the hotel to book.
statement: UPDATE hotels SET booked = B'1' WHERE id = $1;
update-hotel:
kind: postgres-sql
source: my-pg-source
description: >-
Update a hotel's check-in and check-out dates by its ID. Returns a message
indicating whether the hotel was successfully updated or not.
parameters:
- name: hotel_id
type: string
description: The ID of the hotel to update.
- name: checkin_date
type: string
description: The new check-in date of the hotel.
- name: checkout_date
type: string
description: The new check-out date of the hotel.
statement: >-
UPDATE hotels SET checkin_date = CAST($2 as date), checkout_date = CAST($3
as date) WHERE id = $1;
cancel-hotel:
kind: postgres-sql
source: my-pg-source
description: Cancel a hotel by its ID.
parameters:
- name: hotel_id
type: string
description: The ID of the hotel to cancel.
statement: UPDATE hotels SET booked = B'0' WHERE id = $1;
```
For more info on tools, check out the `Resources` section of the docs.
1. Run the Toolbox server, pointing to the `tools.yaml` file created earlier:
```bash
./toolbox --tools_file "tools.yaml"
```
## Step 3: Connect to MCP Inspector
1. Run the MCP Inspector:
```bash
npx @modelcontextprotocol/inspector
```
1. Type `y` when it asks to install the inspector package.
1. It should show the folo=lowing when the MCP Inspector is up and runnning:
```bash
🔍 MCP Inspector is up and running at http://127.0.0.1:5173 🚀
```
1. Open the above link in your browser.
1. For `Transport Type`, select `SSE`.
1. For `URL`, type in `http://127.0.0.1:5000/mcp/sse`.
1. Click Connect.
![inspector](./inspector.png)
1. Select `List Tools`, you will see a list of tools configured in `tools.yaml`.
![inspector_tools](./inspector_tools.png)
1. Test out your tools here!

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 24 KiB

View File

@@ -0,0 +1,76 @@
---
title: "Connect via MCP Client"
type: docs
weight: 1
description: >
How to connect to Toolbox from a MCP Client.
---
## Toolbox SDKs vs Model Context Protocol (MCP)
Toolbox now supports connections via both the native Toolbox SDKs and via [Model Context Protocol (MCP)](https://modelcontextprotocol.io/). However, Toolbox as several features which are not supported in the MCP specification (such as Authenticated Parameters and Authorized invocation).
We recommend using the native SDKs over MCP clients to leverage these features. The native SDKs can be combined with MCP clients in many cases.
### Protocol Versions
Toolbox currently supports the following versions of MCP specification:
* [2024-11-05](https://spec.modelcontextprotocol.io/specification/2024-11-05/)
### Features Not Supported by MCP
Toolbox has several features that are not yet supported in the MCP specification:
* **AuthZ/AuthN:** There are no auth implementation in the `2024-11-05` specification. This includes:
* [Authenticated Parameters](../resources/tools/_index.md#authenticated-parameters)
* [Authorized Invocations](../resources/tools/_index.md#authorized-invocations)
* **Toolsets**: MCP does not have the concept of toolset. Hence, all tools are automatically loaded when using Toolbox with MCP.
* **Notifications:** Currently, editing Toolbox Tools requires a server restart. Clients should reload tools on disconnect to get the latest version.
## Connecting to Toolbox with an MCP client
### Before you begin
{{< notice note >}}
MCP is only compatible with Toolbox version 0.3.0 and above.
{{< /notice >}}
1. [Install](../getting-started/introduction/_index.md#installing-the-server) Toolbox version 0.3.0+.
1. Make sure you've set up and initialized your database.
1. [Set up](../getting-started/configure.md) your `tools.yaml` file.
### Connecting via HTTP
Toolbox supports the HTTP transport protocol with and without SSE.
{{< tabpane text=true >}} {{% tab header="HTTP with SSE" lang="en" %}}
Add the following configuration to your MCP client configuration:
```bash
{
"mcpServers": {
"toolbox": {
"type": "sse",
"url": "http://127.0.0.1:5000/mcp/sse",
}
}
}
```
{{% /tab %}} {{% tab header="HTTP POST" lang="en" %}}
Connect to Toolbox HTTP POST via `http://127.0.0.1:5000/mcp`.
{{% /tab %}} {{< /tabpane >}}
### Using the MCP Inspector with Toolbox
Use MCP [Inspector](https://github.com/modelcontextprotocol/inspector) for testing and debugging Toolbox server.
1. [Run Toolbox](../getting-started/introduction/_index.md#running-the-server).
1. In a separate terminal, run Inspector directly through `npx`:
```bash
npx @modelcontextprotocol/inspector
```
1. For `Transport Type` dropdown menu, select `SSE`.
1. For `URL`, type in `http://127.0.0.1:5000/mcp/sse`.
1. Click the `Connect` button. Voila! You should be able to inspect your toolbox
tools!

View File

@@ -24,8 +24,9 @@ import (
)
func TestToolsetEndpoint(t *testing.T) {
toolsMap, toolsets := setUpResources(t)
ts, shutdown := setUpServer(t, toolsMap, toolsets)
mockTools := []MockTool{tool1, tool2}
toolsMap, toolsets := setUpResources(t, mockTools)
ts, shutdown := setUpServer(t, "api", toolsMap, toolsets)
defer shutdown()
// wantResponse is a struct for checks against test cases
@@ -118,8 +119,9 @@ func TestToolsetEndpoint(t *testing.T) {
}
func TestToolGetEndpoint(t *testing.T) {
toolsMap, toolsets := setUpResources(t)
ts, shutdown := setUpServer(t, toolsMap, toolsets)
mockTools := []MockTool{tool1, tool2}
toolsMap, toolsets := setUpResources(t, mockTools)
ts, shutdown := setUpServer(t, "api", toolsMap, toolsets)
defer shutdown()
// wantResponse is a struct for checks against test cases

View File

@@ -21,8 +21,10 @@ import (
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"
"github.com/go-chi/chi/v5"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/telemetry"
"github.com/googleapis/genai-toolbox/internal/tools"
@@ -38,6 +40,7 @@ type MockTool struct {
Name string
Description string
Params []tools.Parameter
manifest tools.Manifest
}
func (t MockTool) Invoke(tools.ParamValues) ([]any, error) {
@@ -61,6 +64,29 @@ func (t MockTool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t MockTool) McpManifest() tools.McpManifest {
properties := make(map[string]tools.ParameterMcpManifest)
required := make([]string, 0)
for _, p := range t.Params {
name := p.GetName()
properties[name] = p.McpManifest()
required = append(required, name)
}
toolsSchema := tools.McpToolsSchema{
Type: "object",
Properties: properties,
Required: required,
}
return tools.McpManifest{
Name: t.Name,
Description: t.Description,
InputSchema: toolsSchema,
}
}
var tool1 = MockTool{
Name: "no_params",
Params: []tools.Parameter{},
@@ -74,15 +100,29 @@ var tool2 = MockTool{
},
}
var tool3 = MockTool{
Name: "array_param",
Description: "some description",
Params: tools.Parameters{
tools.NewArrayParameter("my_array", "this param is an array of strings", tools.NewStringParameter("my_string", "string item")),
},
}
// setUpResources setups resources to test against
func setUpResources(t *testing.T) (map[string]tools.Tool, map[string]tools.Toolset) {
toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2}
func setUpResources(t *testing.T, mockTools []MockTool) (map[string]tools.Tool, map[string]tools.Toolset) {
toolsMap := make(map[string]tools.Tool)
var allTools []string
for _, tool := range mockTools {
tool.manifest = tool.Manifest()
toolsMap[tool.Name] = tool
allTools = append(allTools, tool.Name)
}
toolsets := make(map[string]tools.Toolset)
for name, l := range map[string][]string{
"": {tool1.Name, tool2.Name},
"tool1_only": {tool1.Name},
"tool2_only": {tool2.Name},
"": allTools,
"tool1_only": {allTools[0]},
"tool2_only": {allTools[1]},
} {
tc := tools.ToolsetConfig{Name: name, ToolNames: l}
m, err := tc.Initialize(fakeVersionString, toolsMap)
@@ -95,7 +135,7 @@ func setUpResources(t *testing.T) (map[string]tools.Tool, map[string]tools.Tools
}
// setUpServer create a new server with tools and toolsets that are given
func setUpServer(t *testing.T, tools map[string]tools.Tool, toolsets map[string]tools.Toolset) (*httptest.Server, func()) {
func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, toolsets map[string]tools.Toolset) (*httptest.Server, func()) {
ctx, cancel := context.WithCancel(context.Background())
testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info")
@@ -113,10 +153,26 @@ func setUpServer(t *testing.T, tools map[string]tools.Tool, toolsets map[string]
t.Fatalf("unable to create custom metrics: %s", err)
}
server := Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, tools: tools, toolsets: toolsets}
r, err := apiRouter(&server)
if err != nil {
t.Fatalf("unable to initialize router: %s", err)
sseManager := &sseManager{
mu: sync.RWMutex{},
sseSessions: make(map[string]*sseSession),
}
server := Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, sseManager: sseManager, tools: tools, toolsets: toolsets}
var r chi.Router
switch router {
case "api":
r, err = apiRouter(&server)
if err != nil {
t.Fatalf("unable to initialize api router: %s", err)
}
case "mcp":
r, err = mcpRouter(&server)
if err != nil {
t.Fatalf("unable to initialize mcp router: %s", err)
}
default:
t.Fatalf("unknown router")
}
ts := httptest.NewServer(r)
shutdown := func() {

View File

@@ -29,6 +29,8 @@ const (
toolsetGetCountName = "toolbox.server.toolset.get.count"
toolGetCountName = "toolbox.server.tool.get.count"
toolInvokeCountName = "toolbox.server.tool.invoke.count"
mcpSseCountName = "toolbox.server.mcp.sse.count"
mcpPostCountName = "toolbox.server.mcp.post.count"
)
// Instrumentation defines the telemetry instrumentation for toolbox
@@ -38,6 +40,8 @@ type Instrumentation struct {
ToolsetGet metric.Int64Counter
ToolGet metric.Int64Counter
ToolInvoke metric.Int64Counter
McpSse metric.Int64Counter
McpPost metric.Int64Counter
}
func CreateTelemetryInstrumentation(versionString string) (*Instrumentation, error) {
@@ -74,12 +78,32 @@ func CreateTelemetryInstrumentation(versionString string) (*Instrumentation, err
return nil, fmt.Errorf("unable to create %s metric: %w", toolInvokeCountName, err)
}
mcpSse, err := meter.Int64Counter(
mcpSseCountName,
metric.WithDescription("Number of MCP SSE connection requests."),
metric.WithUnit("{connection}"),
)
if err != nil {
return nil, fmt.Errorf("unable to create %s metric: %w", mcpSseCountName, err)
}
mcpPost, err := meter.Int64Counter(
mcpPostCountName,
metric.WithDescription("Number of MCP Post calls."),
metric.WithUnit("{call}"),
)
if err != nil {
return nil, fmt.Errorf("unable to create %s metric: %w", mcpPostCountName, err)
}
instrumentation := &Instrumentation{
Tracer: tracer,
meter: meter,
ToolsetGet: toolsetGet,
ToolGet: toolGet,
ToolInvoke: toolInvoke,
McpSse: mcpSse,
McpPost: mcpPost,
}
return instrumentation, nil
}

363
internal/server/mcp.go Normal file
View File

@@ -0,0 +1,363 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"sync"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/render"
"github.com/google/uuid"
"github.com/googleapis/genai-toolbox/internal/server/mcp"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/metric"
)
type sseSession struct {
sessionId string
writer http.ResponseWriter
flusher http.Flusher
done chan struct{}
eventQueue chan string
}
// sseManager manages and control access to sse sessions
type sseManager struct {
mu sync.RWMutex
sseSessions map[string]*sseSession
}
func (m *sseManager) get(id string) (*sseSession, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
session, ok := m.sseSessions[id]
return session, ok
}
func (m *sseManager) add(id string, session *sseSession) {
m.mu.Lock()
m.sseSessions[id] = session
m.mu.Unlock()
}
func (m *sseManager) remove(id string) {
m.mu.Lock()
delete(m.sseSessions, id)
m.mu.Unlock()
}
// mcpRouter creates a router that represents the routes under /mcp
func mcpRouter(s *Server) (chi.Router, error) {
r := chi.NewRouter()
r.Use(middleware.AllowContentType("application/json"))
r.Use(middleware.StripSlashes)
r.Use(render.SetContentType(render.ContentTypeJSON))
r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) })
r.Post("/", func(w http.ResponseWriter, r *http.Request) { mcpHandler(s, w, r) })
return r, nil
}
// sseHandler handles sse initialization and message.
func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp/sse")
r = r.WithContext(ctx)
sessionId := uuid.New().String()
span.SetAttributes(attribute.String("session_id", sessionId))
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
var err error
defer func() {
if err != nil {
span.SetStatus(codes.Error, err.Error())
}
span.End()
status := "success"
if err != nil {
status = "error"
}
s.instrumentation.McpSse.Add(
r.Context(),
1,
metric.WithAttributes(attribute.String("toolbox.sse.sessionId", sessionId)),
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
)
}()
flusher, ok := w.(http.Flusher)
if !ok {
err = fmt.Errorf("unable to retrieve flusher for sse")
s.logger.DebugContext(ctx, err.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
}
session := &sseSession{
sessionId: sessionId,
writer: w,
flusher: flusher,
done: make(chan struct{}),
eventQueue: make(chan string, 100),
}
s.sseManager.add(sessionId, session)
defer s.sseManager.remove(sessionId)
// send initial endpoint event
messageEndpoint := fmt.Sprintf("http://%s/mcp?sessionId=%s", r.Host, sessionId)
s.logger.DebugContext(ctx, fmt.Sprintf("sending endpoint event: %s", messageEndpoint))
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", messageEndpoint)
flusher.Flush()
clientClose := r.Context().Done()
for {
select {
// Ensure that only a single responses are written at once
case event := <-session.eventQueue:
fmt.Fprint(w, event)
s.logger.DebugContext(ctx, fmt.Sprintf("sending event: %s", event))
flusher.Flush()
// channel for client disconnection
case <-clientClose:
close(session.done)
s.logger.DebugContext(ctx, "client disconnected")
return
}
}
}
// mcpHandler handles all mcp messages.
func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp")
r = r.WithContext(ctx)
var id, toolName, method string
var err error
defer func() {
if err != nil {
span.SetStatus(codes.Error, err.Error())
}
span.End()
status := "success"
if err != nil {
status = "error"
}
s.instrumentation.McpPost.Add(
r.Context(),
1,
metric.WithAttributes(attribute.String("toolbox.sse.sessionId", id)),
metric.WithAttributes(attribute.String("toolbox.name", toolName)),
metric.WithAttributes(attribute.String("toolbox.method", method)),
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
)
}()
// Read and returns a body from io.Reader
body, err := io.ReadAll(r.Body)
if err != nil {
// Generate a new uuid if unable to decode
id = uuid.New().String()
s.logger.DebugContext(ctx, err.Error())
render.JSON(w, r, newJSONRPCError(id, mcp.PARSE_ERROR, err.Error(), nil))
}
// Generic baseMessage could either be a JSONRPCNotification or JSONRPCRequest
var baseMessage struct {
Jsonrpc string `json:"jsonrpc"`
Method string `json:"method"`
Id mcp.RequestId `json:"id,omitempty"`
}
if err = decodeJSON(bytes.NewBuffer(body), &baseMessage); err != nil {
// Generate a new uuid if unable to decode
id := uuid.New().String()
s.logger.DebugContext(ctx, err.Error())
render.JSON(w, r, newJSONRPCError(id, mcp.PARSE_ERROR, err.Error(), nil))
return
}
// Check if method is present
if baseMessage.Method == "" {
err = fmt.Errorf("method not found")
s.logger.DebugContext(ctx, err.Error())
render.JSON(w, r, newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil))
return
}
// Check for JSON-RPC 2.0
if baseMessage.Jsonrpc != mcp.JSONRPC_VERSION {
err = fmt.Errorf("invalid json-rpc version")
s.logger.DebugContext(ctx, err.Error())
render.JSON(w, r, newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil))
return
}
// Check if message is a notification
if baseMessage.Id == nil {
id = ""
var notification mcp.JSONRPCNotification
if err = json.Unmarshal(body, &notification); err != nil {
err = fmt.Errorf("invalid notification request: %w", err)
s.logger.DebugContext(ctx, err.Error())
render.JSON(w, r, newJSONRPCError(baseMessage.Id, mcp.PARSE_ERROR, err.Error(), nil))
}
// Notifications do not expect a response
// Toolbox doesn't do anything with notifications yet
w.WriteHeader(http.StatusAccepted)
return
}
id = fmt.Sprintf("%s", baseMessage.Id)
method = baseMessage.Method
var res mcp.JSONRPCMessage
switch baseMessage.Method {
case "initialize":
var req mcp.InitializeRequest
if err = json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp initialize request: %w", err)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
break
}
result := mcp.Initialize(s.version)
res = mcp.JSONRPCResponse{
Jsonrpc: mcp.JSONRPC_VERSION,
Id: baseMessage.Id,
Result: result,
}
case "tools/list":
var req mcp.ListToolsRequest
if err = json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp tools list request: %w", err)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
break
}
toolset, ok := s.toolsets[""]
if !ok {
err = fmt.Errorf("toolset does not exist")
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
break
}
result := mcp.ToolsList(toolset)
res = mcp.JSONRPCResponse{
Jsonrpc: mcp.JSONRPC_VERSION,
Id: baseMessage.Id,
Result: result,
}
case "tools/call":
var req mcp.CallToolRequest
if err = json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp tools call request: %w", err)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
break
}
toolName = req.Params.Name
toolArgument := req.Params.Arguments
tool, ok := s.tools[toolName]
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil)
break
}
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
aMarshal, err := json.Marshal(toolArgument)
if err != nil {
err = fmt.Errorf("unable to marshal tools argument: %w", err)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil)
break
}
var data map[string]any
if err = decodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
err = fmt.Errorf("unable to decode tools argument: %w", err)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil)
break
}
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
// Since MCP doesn't support auth, an empty map will be use every time.
claimsFromAuth := make(map[string]map[string]any)
params, err := tool.ParseParams(data, claimsFromAuth)
if err != nil {
err = fmt.Errorf("provided parameters were invalid: %w", err)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil)
break
}
result := mcp.ToolCall(tool, params)
res = mcp.JSONRPCResponse{
Jsonrpc: mcp.JSONRPC_VERSION,
Id: baseMessage.Id,
Result: result,
}
default:
err = fmt.Errorf("invalid method %s", baseMessage.Method)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil)
}
// retrieve sse session
sessionId := r.URL.Query().Get("sessionId")
session, ok := s.sseManager.get(sessionId)
if !ok {
s.logger.DebugContext(ctx, "sse session not available")
} else {
// queue sse event
eventData, _ := json.Marshal(res)
select {
case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
s.logger.DebugContext(ctx, "event queue successful")
case <-session.done:
s.logger.DebugContext(ctx, "session is close")
default:
s.logger.DebugContext(ctx, "unable to add to event queue")
}
}
// send HTTP response
render.JSON(w, r, res)
}
// newJSONRPCError is the response sent back when an error has been encountered in mcp.
func newJSONRPCError(id mcp.RequestId, code int, message string, data any) mcp.JSONRPCError {
return mcp.JSONRPCError{
Jsonrpc: mcp.JSONRPC_VERSION,
Id: id,
Error: mcp.McpError{
Code: code,
Message: message,
Data: data,
},
}
}

View File

@@ -0,0 +1,74 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mcp
import (
"encoding/json"
"fmt"
"github.com/googleapis/genai-toolbox/internal/tools"
)
func Initialize(version string) InitializeResult {
toolsListChanged := false
result := InitializeResult{
ProtocolVersion: LATEST_PROTOCOL_VERSION,
Capabilities: ServerCapabilities{
Tools: &ListChanged{
ListChanged: &toolsListChanged,
},
},
ServerInfo: Implementation{
Name: SERVER_NAME,
Version: version,
},
}
return result
}
// ToolsList return a ListToolsResult
func ToolsList(toolset tools.Toolset) ListToolsResult {
mcpManifest := toolset.McpManifest
result := ListToolsResult{
Tools: mcpManifest,
}
return result
}
// ToolCall runs tool invocation and return a CallToolResult
func ToolCall(tool tools.Tool, params tools.ParamValues) CallToolResult {
res, err := tool.Invoke(params)
if err != nil {
text := TextContent{
Type: "text",
Text: err.Error(),
}
return CallToolResult{Content: []TextContent{text}, IsError: true}
}
content := make([]TextContent, 0)
for _, d := range res {
text := TextContent{Type: "text"}
dM, err := json.Marshal(d)
if err != nil {
text.Text = fmt.Sprintf("fail to marshal: %s, result: %s", err, d)
} else {
text.Text = string(dM)
}
content = append(content, text)
}
return CallToolResult{Content: content}
}

View File

@@ -0,0 +1,295 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mcp
import (
"github.com/googleapis/genai-toolbox/internal/tools"
)
// SERVER_NAME is the server name used in Implementation.
const SERVER_NAME = "Toolbox"
// LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol.
const LATEST_PROTOCOL_VERSION = "2024-11-05"
// JSONRPC_VERSION is the version of JSON-RPC used by MCP.
const JSONRPC_VERSION = "2.0"
// Standard JSON-RPC error codes
const (
PARSE_ERROR = -32700
INVALID_REQUEST = -32600
METHOD_NOT_FOUND = -32601
INVALID_PARAMS = -32602
INTERNAL_ERROR = -32603
)
// JSONRPCMessage represents either a JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, or JSONRPCError.
type JSONRPCMessage interface{}
// ProgressToken is used to associate progress notifications with the original request.
type ProgressToken interface{}
// Request represents a bidirectional message with method and parameters expecting a response.
type Request struct {
Method string `json:"method"`
Params struct {
Meta struct {
// If specified, the caller is requesting out-of-band progress
// notifications for this request (as represented by
// notifications/progress). The value of this parameter is an
// opaque token that will be attached to any subsequent
// notifications. The receiver is not obligated to provide these
// notifications.
ProgressToken ProgressToken `json:"progressToken,omitempty"`
} `json:"_meta,omitempty"`
} `json:"params,omitempty"`
}
// Notification is a one-way message requiring no response.
type Notification struct {
Method string `json:"method"`
Params struct {
Meta map[string]interface{} `json:"_meta,omitempty"`
} `json:"params,omitempty"`
}
// Result represents a response for the request query.
type Result struct {
// This result property is reserved by the protocol to allow clients and
// servers to attach additional metadata to their responses.
Meta map[string]interface{} `json:"_meta,omitempty"`
}
// RequestId is a uniquely identifying ID for a request in JSON-RPC.
// It can be any JSON-serializable value, typically a number or string.
type RequestId interface{}
// JSONRPCRequest represents a request that expects a response.
type JSONRPCRequest struct {
Jsonrpc string `json:"jsonrpc"`
Id RequestId `json:"id"`
Request
Params any `json:"params,omitempty"`
}
// JSONRPCNotification represents a notification which does not expect a response.
type JSONRPCNotification struct {
Jsonrpc string `json:"jsonrpc"`
Notification
}
// JSONRPCResponse represents a successful (non-error) response to a request.
type JSONRPCResponse struct {
Jsonrpc string `json:"jsonrpc"`
Id RequestId `json:"id"`
Result interface{} `json:"result"`
}
// McpError represents the error content.
type McpError struct {
// The error type that occurred.
Code int `json:"code"`
// A short description of the error. The message SHOULD be limited
// to a concise single sentence.
Message string `json:"message"`
// Additional information about the error. The value of this member
// is defined by the sender (e.g. detailed error information, nested errors etc.).
Data interface{} `json:"data,omitempty"`
}
// JSONRPCError represents a non-successful (error) response to a request.
type JSONRPCError struct {
Jsonrpc string `json:"jsonrpc"`
Id RequestId `json:"id"`
Error McpError `json:"error"`
}
/* Empty result */
// EmptyResult represents a response that indicates success but carries no data.
type EmptyResult Result
/* Initialization */
// Params to define MCP Client during initialize request.
type InitializeParams struct {
// The latest version of the Model Context Protocol that the client supports.
// The client MAY decide to support older versions as well.
ProtocolVersion string `json:"protocolVersion"`
Capabilities ClientCapabilities `json:"capabilities"`
ClientInfo Implementation `json:"clientInfo"`
}
// InitializeRequest is sent from the client to the server when it first
// connects, asking it to begin initialization.
type InitializeRequest struct {
Request
Params InitializeParams `json:"params"`
}
// InitializeResult is sent after receiving an initialize request from the
// client.
type InitializeResult struct {
Result
// The version of the Model Context Protocol that the server wants to use.
// This may not match the version that the client requested. If the client cannot
// support this version, it MUST disconnect.
ProtocolVersion string `json:"protocolVersion"`
Capabilities ServerCapabilities `json:"capabilities"`
ServerInfo Implementation `json:"serverInfo"`
// Instructions describing how to use the server and its features.
//
// This can be used by clients to improve the LLM's understanding of
// available tools, resources, etc. It can be thought of like a "hint" to the model.
// For example, this information MAY be added to the system prompt.
Instructions string `json:"instructions,omitempty"`
}
// InitializedNotification is sent from the client to the server after
// initialization has finished.
type InitializedNotification struct {
Notification
}
// ListChange represents whether the server supports notification for changes to the capabilities.
type ListChanged struct {
ListChanged *bool `json:"listChanged,omitempty"`
}
// ClientCapabilities represents capabilities a client may support. Known
// capabilities are defined here, in this schema, but this is not a closed set: any
// client can define its own, additional capabilities.
type ClientCapabilities struct {
// Experimental, non-standard capabilities that the client supports.
Experimental map[string]interface{} `json:"experimental,omitempty"`
// Present if the client supports listing roots.
Roots *ListChanged `json:"roots,omitempty"`
// Present if the client supports sampling from an LLM.
Sampling struct{} `json:"sampling,omitempty"`
}
// ServerCapabilities represents capabilities that a server may support. Known
// capabilities are defined here, in this schema, but this is not a closed set: any
// server can define its own, additional capabilities.
type ServerCapabilities struct {
Tools *ListChanged `json:"tools,omitempty"`
}
// Implementation describes the name and version of an MCP implementation.
type Implementation struct {
Name string `json:"name"`
Version string `json:"version"`
}
/* Pagination */
// Cursor is an opaque token used to represent a cursor for pagination.
type Cursor string
type PaginatedRequest struct {
Request
Params struct {
// An opaque token representing the current pagination position.
// If provided, the server should return results starting after this cursor.
Cursor Cursor `json:"cursor,omitempty"`
} `json:"params,omitempty"`
}
type PaginatedResult struct {
Result
// An opaque token representing the pagination position after the last returned result.
// If present, there may be more results available.
NextCursor Cursor `json:"nextCursor,omitempty"`
}
/* Tools */
// Sent from the client to request a list of tools the server has.
type ListToolsRequest struct {
PaginatedRequest
}
// The server's response to a tools/list request from the client.
type ListToolsResult struct {
PaginatedResult
Tools []tools.McpManifest `json:"tools"`
}
// Used by the client to invoke a tool provided by the server.
type CallToolRequest struct {
Request
Params struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
} `json:"params,omitempty"`
}
// The sender or recipient of messages and data in a conversation.
type Role string
const (
RoleUser Role = "user"
RoleAssistant Role = "assistant"
)
// Base for objects that include optional annotations for the client.
// The client can use annotations to inform how objects are used or displayed
type Annotated struct {
Annotations *struct {
// Describes who the intended customer of this object or data is.
// It can include multiple entries to indicate content useful for multiple
// audiences (e.g., `["user", "assistant"]`).
Audience []Role `json:"audience,omitempty"`
// Describes how important this data is for operating the server.
//
// A value of 1 means "most important," and indicates that the data is
// effectively required, while 0 means "least important," and indicates that
// the data is entirely optional.
//
// @TJS-type number
// @minimum 0
// @maximum 1
Priority float64 `json:"priority,omitempty"`
} `json:"annotations,omitempty"`
}
// TextContent represents text provided to or from an LLM.
type TextContent struct {
Annotated
Type string `json:"type"`
// The text content of the message.
Text string `json:"text"`
}
// The server's response to a tool call.
//
// Any errors that originate from the tool SHOULD be reported inside the result
// object, with `isError` set to true, _not_ as an MCP protocol-level error
// response. Otherwise, the LLM would not be able to see that an error occurred
// and self-correct.
//
// However, any errors in _finding_ the tool, an error indicating that the
// server does not support tool calls, or any other exceptional conditions,
// should be reported as an MCP error response.
type CallToolResult struct {
Result
// Could be either a TextContent, ImageContent, or EmbeddedResources
// For Toolbox, we will only be sending TextContent
Content []TextContent `json:"content"`
// Whether the tool call ended in an error.
// If not set, this is assumed to be false (the call was successful).
IsError bool `json:"isError,omitempty"`
}

258
internal/server/mcp_test.go Normal file
View File

@@ -0,0 +1,258 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"bytes"
"encoding/json"
"net/http"
"reflect"
"strings"
"testing"
"github.com/googleapis/genai-toolbox/internal/server/mcp"
)
const jsonrpcVersion = "2.0"
const protocolVersion = "2024-11-05"
const serverName = "Toolbox"
var tool1InputSchema = map[string]any{
"type": "object",
"properties": map[string]any{},
"required": []any{},
}
var tool2InputSchema = map[string]any{
"type": "object",
"properties": map[string]any{
"param1": map[string]any{"type": "integer", "description": "This is the first parameter."},
"param2": map[string]any{"type": "integer", "description": "This is the second parameter."},
},
"required": []any{"param1", "param2"},
}
var tool3InputSchema = map[string]any{
"type": "object",
"properties": map[string]any{
"my_array": map[string]any{
"type": "array",
"description": "this param is an array of strings",
"items": map[string]any{"type": "string", "description": "string item"},
},
},
"required": []any{"my_array"},
}
func TestMcpEndpoint(t *testing.T) {
mockTools := []MockTool{tool1, tool2, tool3}
toolsMap, toolsets := setUpResources(t, mockTools)
ts, shutdown := setUpServer(t, "mcp", toolsMap, toolsets)
defer shutdown()
testCases := []struct {
name string
isErr bool
body mcp.JSONRPCRequest
want map[string]any
}{
{
name: "initialize",
body: mcp.JSONRPCRequest{
Jsonrpc: jsonrpcVersion,
Id: "mcp-initialize",
Request: mcp.Request{
Method: "initialize",
},
},
want: map[string]any{
"jsonrpc": "2.0",
"id": "mcp-initialize",
"result": map[string]any{
"protocolVersion": protocolVersion,
"capabilities": map[string]any{
"tools": map[string]any{"listChanged": false},
},
"serverInfo": map[string]any{"name": serverName, "version": fakeVersionString},
},
},
},
{
name: "basic notification",
body: mcp.JSONRPCRequest{
Jsonrpc: jsonrpcVersion,
Request: mcp.Request{
Method: "notification",
},
},
},
{
name: "tools/list",
body: mcp.JSONRPCRequest{
Jsonrpc: jsonrpcVersion,
Id: "tools-list",
Request: mcp.Request{
Method: "tools/list",
},
},
want: map[string]any{
"jsonrpc": "2.0",
"id": "tools-list",
"result": map[string]any{
"tools": []any{
map[string]any{
"name": "no_params",
"inputSchema": tool1InputSchema,
},
map[string]any{
"name": "some_params",
"inputSchema": tool2InputSchema,
},
map[string]any{
"name": "array_param",
"description": "some description",
"inputSchema": tool3InputSchema,
},
},
},
},
},
{
name: "missing method",
isErr: true,
body: mcp.JSONRPCRequest{
Jsonrpc: jsonrpcVersion,
Id: "missing-method",
Request: mcp.Request{},
},
want: map[string]any{
"jsonrpc": "2.0",
"id": "missing-method",
"error": map[string]any{
"code": -32601.0,
"message": "method not found",
},
},
},
{
name: "invalid method",
isErr: true,
body: mcp.JSONRPCRequest{
Jsonrpc: jsonrpcVersion,
Id: "invalid-method",
Request: mcp.Request{
Method: "foo",
},
},
want: map[string]any{
"jsonrpc": "2.0",
"id": "invalid-method",
"error": map[string]any{
"code": -32601.0,
"message": "invalid method foo",
},
},
},
{
name: "invalid jsonrpc version",
isErr: true,
body: mcp.JSONRPCRequest{
Jsonrpc: "1.0",
Id: "invalid-jsonrpc-version",
Request: mcp.Request{
Method: "foo",
},
},
want: map[string]any{
"jsonrpc": "2.0",
"id": "invalid-jsonrpc-version",
"error": map[string]any{
"code": -32600.0,
"message": "invalid json-rpc version",
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
reqMarshal, err := json.Marshal(tc.body)
if err != nil {
t.Fatalf("unexpected error during marshaling of body")
}
resp, body, err := runRequest(ts, http.MethodPost, "/", bytes.NewBuffer(reqMarshal))
if err != nil {
t.Fatalf("unexpected error during request: %s", err)
}
// Notifications don't expect a response.
if tc.want != nil {
if contentType := resp.Header.Get("Content-type"); contentType != "application/json" {
t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType)
}
var got map[string]any
if err := json.Unmarshal(body, &got); err != nil {
t.Fatalf("unexpected error unmarshalling body: %s", err)
}
if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("unexpected response: got %+v, want %+v", got, tc.want)
}
}
})
}
}
func TestSseEndpoint(t *testing.T) {
ts, shutdown := setUpServer(t, "mcp", nil, nil)
defer shutdown()
contentType := "text/event-stream"
cacheControl := "no-cache"
connection := "keep-alive"
accessControlAllowOrigin := "*"
wantEvent := "event: endpoint"
t.Run("test sse endpoint", func(t *testing.T) {
resp, err := http.Get(ts.URL + "/sse")
if err != nil {
t.Fatalf("unexpected error during request: %s", err)
}
defer resp.Body.Close()
if gotContentType := resp.Header.Get("Content-type"); gotContentType != contentType {
t.Fatalf("unexpected content-type header: want %s, got %s", contentType, gotContentType)
}
if gotCacheControl := resp.Header.Get("Cache-Control"); gotCacheControl != cacheControl {
t.Fatalf("unexpected cache-control header: want %s, got %s", cacheControl, gotCacheControl)
}
if gotConnection := resp.Header.Get("Connection"); gotConnection != connection {
t.Fatalf("unexpected content-type header: want %s, got %s", connection, gotConnection)
}
if gotAccessControlAllowOrigin := resp.Header.Get("Access-Control-Allow-Origin"); gotAccessControlAllowOrigin != accessControlAllowOrigin {
t.Fatalf("unexpected cache-control header: want %s, got %s", accessControlAllowOrigin, gotAccessControlAllowOrigin)
}
buffer := make([]byte, 1024)
n, err := resp.Body.Read(buffer)
if err != nil {
t.Fatalf("unable to read response: %s", err)
}
endpointEvent := string(buffer[:n])
if !strings.Contains(endpointEvent, wantEvent) {
t.Fatalf("unexpected event: got %s", endpointEvent)
}
})
}

View File

@@ -20,6 +20,7 @@ import (
"net"
"net/http"
"strconv"
"sync"
"time"
"github.com/go-chi/chi/v5"
@@ -42,6 +43,7 @@ type Server struct {
root chi.Router
logger log.Logger
instrumentation *Instrumentation
sseManager *sseManager
sources map[string]sources.Source
authServices map[string]auth.AuthService
@@ -203,12 +205,18 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
addr := net.JoinHostPort(cfg.Address, strconv.Itoa(cfg.Port))
srv := &http.Server{Addr: addr, Handler: r}
sseManager := &sseManager{
mu: sync.RWMutex{},
sseSessions: make(map[string]*sseSession),
}
s := &Server{
version: cfg.Version,
srv: srv,
root: r,
logger: l,
instrumentation: instrumentation,
sseManager: sseManager,
sources: sourcesMap,
authServices: authServicesMap,
@@ -221,6 +229,11 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
return nil, err
}
r.Mount("/api", apiR)
mcpR, err := mcpRouter(s)
if err != nil {
return nil, err
}
r.Mount("/mcp", mcpR)
// default endpoint for validating server is running
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("🧰 Hello, World! 🧰"))

View File

@@ -65,6 +65,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
}
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
InputSchema: cfg.Parameters.McpManifest(),
}
// finish tool setup
t := Tool{
Name: cfg.Name,
@@ -75,6 +81,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
IsQuery: cfg.IsQuery,
Timeout: cfg.Timeout,
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -92,6 +99,7 @@ type Tool struct {
Timeout string
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
@@ -127,6 +135,10 @@ func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -91,7 +91,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// Create a slice for all parameters
allParameters := slices.Concat(cfg.BodyParams, cfg.HeaderParams, cfg.QueryParams)
// Create parameter manifest
// Create parameter MCP manifest
paramManifest := slices.Concat(
cfg.QueryParams.Manifest(),
cfg.BodyParams.Manifest(),
@@ -101,6 +101,36 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
paramManifest = make([]tools.ParameterManifest, 0)
}
queryMcpManifest := cfg.QueryParams.McpManifest()
bodyMcpManifest := cfg.BodyParams.McpManifest()
headerMcpManifest := cfg.HeaderParams.McpManifest()
// Concatenate parameters for MCP `required` field
concatRequiredManifest := slices.Concat(
queryMcpManifest.Required,
bodyMcpManifest.Required,
headerMcpManifest.Required,
)
// Concatenate parameters for MCP `properties` field
concatPropertiesManifest := make(map[string]tools.ParameterMcpManifest)
for name, p := range queryMcpManifest.Properties {
concatPropertiesManifest[name] = p
}
for name, p := range bodyMcpManifest.Properties {
concatPropertiesManifest[name] = p
}
for name, p := range headerMcpManifest.Properties {
concatPropertiesManifest[name] = p
}
// Create a new McpToolsSchema with all parameters
paramMcpManifest := tools.McpToolsSchema{
Type: "object",
Properties: concatPropertiesManifest,
Required: concatRequiredManifest,
}
// Verify there are no duplicate parameter names
seenNames := make(map[string]bool)
for _, param := range paramManifest {
@@ -110,6 +140,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
seenNames[param.Name] = true
}
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
InputSchema: paramMcpManifest,
}
// finish tool setup
return Tool{
Name: cfg.Name,
@@ -125,6 +161,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Client: s.Client,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest},
mcpManifest: mcpManifest,
}, nil
}
@@ -146,8 +183,9 @@ type Tool struct {
HeaderParams tools.Parameters `yaml:"headerParams"`
AllParams tools.Parameters `yaml:"allParams"`
Client *http.Client
manifest tools.Manifest
Client *http.Client
manifest tools.Manifest
mcpManifest tools.McpManifest
}
// helper function to convert a parameter to JSON formatted string.
@@ -271,6 +309,10 @@ func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -68,6 +68,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
}
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
InputSchema: cfg.Parameters.McpManifest(),
}
// finish tool setup
t := Tool{
Name: cfg.Name,
@@ -77,6 +83,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
AuthRequired: cfg.AuthRequired,
Db: s.MSSQLDB(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -90,9 +97,10 @@ type Tool struct {
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Db *sql.DB
Statement string
manifest tools.Manifest
Db *sql.DB
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
@@ -157,6 +165,10 @@ func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -67,6 +67,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
}
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
InputSchema: cfg.Parameters.McpManifest(),
}
// finish tool setup
t := Tool{
Name: cfg.Name,
@@ -76,6 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
AuthRequired: cfg.AuthRequired,
Pool: s.MySQLPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -89,9 +96,10 @@ type Tool struct {
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Pool *sql.DB
Statement string
manifest tools.Manifest
Pool *sql.DB
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
@@ -152,6 +160,10 @@ func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -66,15 +66,22 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
}
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
InputSchema: cfg.Parameters.McpManifest(),
}
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: ToolKind,
Parameters: cfg.Parameters,
Statement: cfg.Statement,
Driver: s.Neo4jDriver(),
Database: s.Neo4jDatabase(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
Name: cfg.Name,
Kind: ToolKind,
Parameters: cfg.Parameters,
Statement: cfg.Statement,
Driver: s.Neo4jDriver(),
Database: s.Neo4jDatabase(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -88,10 +95,11 @@ type Tool struct {
Parameters tools.Parameters `yaml:"parameters"`
AuthRequired []string `yaml:"authRequired"`
Driver neo4j.DriverWithContext
Database string
Statement string
manifest tools.Manifest
Driver neo4j.DriverWithContext
Database string
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
@@ -127,6 +135,10 @@ func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -154,6 +154,14 @@ type Parameter interface {
GetAuthServices() []ParamAuthService
Parse(any) (any, error)
Manifest() ParameterManifest
McpManifest() ParameterMcpManifest
}
// McpToolsSchema is the representation of input schema for McpManifest.
type McpToolsSchema struct {
Type string `json:"type"`
Properties map[string]ParameterMcpManifest `json:"properties"`
Required []string `json:"required"`
}
// Parameters is a type used to allow unmarshal a list of parameters
@@ -265,6 +273,24 @@ func (ps Parameters) Manifest() []ParameterManifest {
return rtn
}
func (ps Parameters) McpManifest() McpToolsSchema {
properties := make(map[string]ParameterMcpManifest)
required := make([]string, 0)
for _, p := range ps {
name := p.GetName()
properties[name] = p.McpManifest()
// all parameters are added to the required field
required = append(required, name)
}
return McpToolsSchema{
Type: "object",
Properties: properties,
Required: required,
}
}
// ParameterManifest represents parameters when served as part of a ToolManifest.
type ParameterManifest struct {
Name string `json:"name"`
@@ -274,6 +300,13 @@ type ParameterManifest struct {
Items *ParameterManifest `json:"items,omitempty"`
}
// ParameterMcpManifest represents properties when served as part of a ToolMcpManifest.
type ParameterMcpManifest struct {
Type string `json:"type"`
Description string `json:"description"`
Items *ParameterMcpManifest `json:"items,omitempty"`
}
// CommonParameter are default fields that are emebdding in most Parameter implementations. Embedding this stuct will give the object Name() and Type() functions.
type CommonParameter struct {
Name string `yaml:"name" validate:"required"`
@@ -308,6 +341,14 @@ func (p *CommonParameter) Manifest() ParameterManifest {
}
}
// McpManifest returns the MCP manifest for the Parameter.
func (p *CommonParameter) McpManifest() ParameterMcpManifest {
return ParameterMcpManifest{
Type: p.Type,
Description: p.Desc,
}
}
// ParseTypeError is a custom error for incorrectly typed Parameters.
type ParseTypeError struct {
Name string
@@ -611,3 +652,18 @@ func (p *ArrayParameter) Manifest() ParameterManifest {
Items: &items,
}
}
// McpManifest returns the MCP manifest for the ArrayParameter.
func (p *ArrayParameter) McpManifest() ParameterMcpManifest {
// only list ParamAuthService names (without fields) in manifest
authNames := make([]string, len(p.AuthServices))
for i, a := range p.AuthServices {
authNames[i] = a.Name
}
items := p.Items.McpManifest()
return ParameterMcpManifest{
Type: p.Type,
Description: p.Desc,
Items: &items,
}
}

View File

@@ -847,6 +847,52 @@ func TestParamManifest(t *testing.T) {
}
}
func TestParamMcpManifest(t *testing.T) {
tcs := []struct {
name string
in tools.Parameter
want tools.ParameterMcpManifest
}{
{
name: "string",
in: tools.NewStringParameter("foo-string", "bar"),
want: tools.ParameterMcpManifest{Type: "string", Description: "bar"},
},
{
name: "int",
in: tools.NewIntParameter("foo-int", "bar"),
want: tools.ParameterMcpManifest{Type: "integer", Description: "bar"},
},
{
name: "float",
in: tools.NewFloatParameter("foo-float", "bar"),
want: tools.ParameterMcpManifest{Type: "float", Description: "bar"},
},
{
name: "boolean",
in: tools.NewBooleanParameter("foo-bool", "bar"),
want: tools.ParameterMcpManifest{Type: "boolean", Description: "bar"},
},
{
name: "array",
in: tools.NewArrayParameter("foo-array", "bar", tools.NewStringParameter("foo-string", "bar")),
want: tools.ParameterMcpManifest{
Type: "array",
Description: "bar",
Items: &tools.ParameterMcpManifest{Type: "string", Description: "bar"},
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
got := tc.in.McpManifest()
if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("unexpected manifest: got %+v, want %+v", got, tc.want)
}
})
}
}
func TestFailParametersUnmarshal(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {

View File

@@ -69,6 +69,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
}
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
InputSchema: cfg.Parameters.McpManifest(),
}
// finish tool setup
t := Tool{
Name: cfg.Name,
@@ -78,6 +84,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
AuthRequired: cfg.AuthRequired,
Pool: s.PostgresPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -91,9 +98,10 @@ type Tool struct {
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Pool *pgxpool.Pool
Statement string
manifest tools.Manifest
Pool *pgxpool.Pool
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(params tools.ParamValues) ([]any, error) {
@@ -129,6 +137,10 @@ func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -68,6 +68,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
}
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
InputSchema: cfg.Parameters.McpManifest(),
}
// finish tool setup
t := Tool{
Name: cfg.Name,
@@ -78,6 +84,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Client: s.SpannerClient(),
dialect: s.DatabaseDialect(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -91,10 +98,11 @@ type Tool struct {
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Client *spanner.Client
dialect string
Statement string
manifest tools.Manifest
Client *spanner.Client
dialect string
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func getMapParams(params tools.ParamValues, dialect string) (map[string]interface{}, error) {
@@ -157,6 +165,10 @@ func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -29,6 +29,7 @@ type Tool interface {
Invoke(ParamValues) ([]any, error)
ParseParams(map[string]any, map[string]map[string]any) (ParamValues, error)
Manifest() Manifest
McpManifest() McpManifest
Authorized([]string) bool
}
@@ -38,6 +39,16 @@ type Manifest struct {
Parameters []ParameterManifest `json:"parameters"`
}
// Definition for a tool the MCP client can call.
type McpManifest struct {
// The name of the tool.
Name string `json:"name"`
// A human-readable description of the tool.
Description string `json:"description,omitempty"`
// A JSON Schema object defining the expected parameters for the tool.
InputSchema McpToolsSchema `json:"inputSchema,omitempty"`
}
// Helper function that returns if a tool invocation request is authorized
func IsAuthorized(authRequiredSources []string, verifiedAuthServices []string) bool {
if len(authRequiredSources) == 0 {

View File

@@ -24,9 +24,10 @@ type ToolsetConfig struct {
}
type Toolset struct {
Name string `yaml:"name"`
Tools []*Tool `yaml:",inline"`
Manifest ToolsetManifest `yaml:",inline"`
Name string `yaml:"name"`
Tools []*Tool `yaml:",inline"`
Manifest ToolsetManifest `yaml:",inline"`
McpManifest []McpManifest `yaml:",inline"`
}
type ToolsetManifest struct {
@@ -54,6 +55,7 @@ func (t ToolsetConfig) Initialize(serverVersion string, toolsMap map[string]Tool
}
toolset.Tools = append(toolset.Tools, &tool)
toolset.Manifest.ToolsManifest[toolName] = tool.Manifest()
toolset.McpManifest = append(toolset.McpManifest, tool.McpManifest())
}
return toolset, nil

View File

@@ -163,7 +163,9 @@ func TestAlloyDBToolEndpoints(t *testing.T) {
RunToolGetTest(t)
select_1_want := "[{\"?column?\":1}]"
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ERROR: syntax error at or near \"SELEC\" (SQLSTATE 42601)"}],"isError":true}}`
RunToolInvokeTest(t, select_1_want)
RunMCPToolCallMethod(t, fail_invocation_want)
}
// Test connection with different IP type

View File

@@ -150,7 +150,9 @@ func TestCloudSQLMssqlToolEndpoints(t *testing.T) {
RunToolGetTest(t)
select_1_want := "[{\"\":1}]"
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: mssql: Could not find stored procedure 'SELEC'."}],"isError":true}}`
RunToolInvokeTest(t, select_1_want)
RunMCPToolCallMethod(t, fail_invocation_want)
}
// Test connection with different IP type

View File

@@ -145,7 +145,9 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
RunToolGetTest(t)
select_1_want := "[{\"1\":1}]"
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}`
RunToolInvokeTest(t, select_1_want)
RunMCPToolCallMethod(t, fail_invocation_want)
}
// Test connection with different IP type

View File

@@ -149,7 +149,9 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
RunToolGetTest(t)
select_1_want := "[{\"?column?\":1}]"
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ERROR: syntax error at or near \"SELEC\" (SQLSTATE 42601)"}],"isError":true}}`
RunToolInvokeTest(t, select_1_want)
RunMCPToolCallMethod(t, fail_invocation_want)
}
// Test connection with different IP type

View File

@@ -93,6 +93,12 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, param_tool_statement,
"my-google-auth",
},
},
"my-fail-tool": map[string]any{
"kind": toolKind,
"source": "my-instance",
"description": "Tool to test statement with incorrect syntax.",
"statement": "SELEC 1;",
},
},
}

View File

@@ -122,5 +122,7 @@ func TestMsSQLToolEndpoints(t *testing.T) {
RunToolGetTest(t)
select_1_want := "[{\"\":1}]"
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: mssql: Could not find stored procedure 'SELEC'."}],"isError":true}}`
RunToolInvokeTest(t, select_1_want)
RunMCPToolCallMethod(t, fail_invocation_want)
}

View File

@@ -121,5 +121,7 @@ func TestMySQLToolEndpoints(t *testing.T) {
RunToolGetTest(t)
select_1_want := "[{\"1\":1}]"
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}`
RunToolInvokeTest(t, select_1_want)
RunMCPToolCallMethod(t, fail_invocation_want)
}

View File

@@ -121,5 +121,7 @@ func TestPostgres(t *testing.T) {
RunToolGetTest(t)
select_1_want := "[{\"?column?\":1}]"
fail_invocation_want := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ERROR: syntax error at or near \"SELEC\" (SQLSTATE 42601)"}],"isError":true}}`
RunToolInvokeTest(t, select_1_want)
RunMCPToolCallMethod(t, fail_invocation_want)
}

View File

@@ -28,6 +28,7 @@ import (
"strings"
"testing"
"github.com/googleapis/genai-toolbox/internal/server/mcp"
"github.com/jackc/pgx/v5/pgxpool"
)
@@ -304,3 +305,136 @@ func RunToolInvokeTest(t *testing.T, select_1_want string) {
})
}
}
// RunMCPToolCallMethod runs the tool/call for mcp endpoint
func RunMCPToolCallMethod(t *testing.T, fail_invocation_want string) {
// Test tool invoke endpoint
invokeTcs := []struct {
name string
api string
requestBody mcp.JSONRPCRequest
requestHeader map[string]string
want string
}{
{
name: "MCP Invoke my-param-tool",
api: "http://127.0.0.1:5000/mcp",
requestHeader: map[string]string{},
requestBody: mcp.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "my-param-tool",
Request: mcp.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": "my-param-tool",
"arguments": map[string]any{
"id": int(3),
"name": "Alice",
},
},
},
want: `{"jsonrpc":"2.0","id":"my-param-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`,
},
{
name: "MCP Invoke invalid tool",
api: "http://127.0.0.1:5000/mcp",
requestHeader: map[string]string{},
requestBody: mcp.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "invalid-tool",
Request: mcp.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": "foo",
"arguments": map[string]any{},
},
},
want: `{"jsonrpc":"2.0","id":"invalid-tool","error":{"code":-32602,"message":"invalid tool name: tool with name \"foo\" does not exist"}}`,
},
{
name: "MCP Invoke my-tool without parameters",
api: "http://127.0.0.1:5000/mcp",
requestHeader: map[string]string{},
requestBody: mcp.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "invoke-without-parameter",
Request: mcp.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": "my-tool",
"arguments": map[string]any{},
},
},
want: `{"jsonrpc":"2.0","id":"invoke-without-parameter","error":{"code":-32602,"message":"invalid tool name: tool with name \"my-tool\" does not exist"}}`,
},
{
name: "MCP Invoke my-tool with insufficient parameters",
api: "http://127.0.0.1:5000/mcp",
requestHeader: map[string]string{},
requestBody: mcp.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "invoke-insufficient-parameter",
Request: mcp.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": "my-tool",
"arguments": map[string]any{"id": 1},
},
},
want: `{"jsonrpc":"2.0","id":"invoke-insufficient-parameter","error":{"code":-32602,"message":"invalid tool name: tool with name \"my-tool\" does not exist"}}`,
},
{
name: "MCP Invoke my-fail-tool",
api: "http://127.0.0.1:5000/mcp",
requestHeader: map[string]string{},
requestBody: mcp.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "invoke-fail-tool",
Request: mcp.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": "my-fail-tool",
"arguments": map[string]any{"id": 1},
},
},
want: fail_invocation_want,
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
reqMarshal, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("unexpected error during marshaling of request body")
}
// Send Tool invocation request
req, err := http.NewRequest(http.MethodPost, tc.api, bytes.NewBuffer(reqMarshal))
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
for k, v := range tc.requestHeader {
req.Header.Add(k, v)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unable to read request body: %s", err)
}
defer resp.Body.Close()
got := string(bytes.TrimSpace(respBody))
if got != tc.want {
fmt.Printf("res is %s\n\n", got)
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
})
}
}