feat: support MCP stdio transport protocol (#607)

Support MCP
[stdio](https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#stdio)
transport protocol!

To run stdio with Toolbox, user have to use the `--stdio` flag.

Example of running MCP Toolbox with MCP Inspector via stdio transport
protocol: `npx @modelcontextprotocol/inspector ./toolbox --stdio`.

---------

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
Co-authored-by: Averi Kitsch <akitsch@google.com>
This commit is contained in:
Yuan
2025-05-28 10:10:34 -07:00
committed by GitHub
parent 5292e12588
commit 1702ce1e00
7 changed files with 453 additions and 160 deletions

View File

@@ -71,12 +71,14 @@ type Command struct {
cfg server.ServerConfig
logger log.Logger
tools_file string
inStream io.Reader
outStream io.Writer
errStream io.Writer
}
// NewCommand returns a Command object representing an invocation of the CLI.
func NewCommand(opts ...Option) *Command {
in := os.Stdin
out := os.Stdout
err := os.Stderr
@@ -87,6 +89,7 @@ func NewCommand(opts ...Option) *Command {
}
cmd := &Command{
Command: baseCmd,
inStream: in,
outStream: out,
errStream: err,
}
@@ -98,7 +101,8 @@ func NewCommand(opts ...Option) *Command {
// Set server version
cmd.cfg.Version = versionString
// set baseCmd out and err the same as cmd.
// set baseCmd in, out and err the same as cmd.
baseCmd.SetIn(cmd.inStream)
baseCmd.SetOut(cmd.outStream)
baseCmd.SetErr(cmd.errStream)
@@ -115,6 +119,7 @@ func NewCommand(opts ...Option) *Command {
flags.BoolVar(&cmd.cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.")
flags.StringVar(&cmd.cfg.TelemetryOTLP, "telemetry-otlp", "", "Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318')")
flags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.")
flags.BoolVar(&cmd.cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.")
// wrap RunE command so that we have access to original Command object
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }
@@ -163,7 +168,25 @@ func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
return toolsFile, nil
}
// updateLogLevel checks if Toolbox have to update the existing log level set by users.
// stdio doesn't support "debug" and "info" logs.
func updateLogLevel(stdio bool, logLevel string) bool {
if stdio {
switch strings.ToUpper(logLevel) {
case log.Debug, log.Info:
return true
default:
return false
}
}
return false
}
func run(cmd *Command) error {
if updateLogLevel(cmd.cfg.Stdio, cmd.cfg.LogLevel.String()) {
cmd.cfg.LogLevel = server.StringLevel(log.Warn)
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
@@ -262,9 +285,16 @@ func run(cmd *Command) error {
srvErr := make(chan error)
go func() {
defer close(srvErr)
err = s.Serve(ctx)
if err != nil {
srvErr <- err
if cmd.cfg.Stdio {
err = s.ServeStdio(ctx, cmd.inStream, cmd.outStream)
if err != nil {
srvErr <- err
}
} else {
err = s.Serve(ctx)
if err != nil {
srvErr <- err
}
}
}()

View File

@@ -163,6 +163,13 @@ func TestServerConfigFlags(t *testing.T) {
TelemetryServiceName: "toolbox-custom",
}),
},
{
desc: "stdio",
args: []string{"--stdio"},
want: withDefaults(server.ServerConfig{
Stdio: true,
}),
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
@@ -858,3 +865,51 @@ func TestEnvVarReplacement(t *testing.T) {
}
}
func TestUpdateLogLevel(t *testing.T) {
tcs := []struct {
desc string
stdio bool
logLevel string
want bool
}{
{
desc: "no stdio",
stdio: false,
logLevel: "info",
want: false,
},
{
desc: "stdio with info log",
stdio: true,
logLevel: "info",
want: true,
},
{
desc: "stdio with debug log",
stdio: true,
logLevel: "debug",
want: true,
},
{
desc: "stdio with warn log",
stdio: true,
logLevel: "warn",
want: false,
},
{
desc: "stdio with error log",
stdio: true,
logLevel: "error",
want: false,
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := updateLogLevel(tc.stdio, tc.logLevel)
if got != tc.want {
t.Fatalf("incorrect indication to update log level: got %t, want %t", got, tc.want)
}
})
}
}

View File

@@ -36,6 +36,20 @@ MCP is only compatible with Toolbox version 0.3.0 and above.
1. [Set up](../getting-started/configure.md) your `tools.yaml` file.
### Connecting via Standard Input/Output (stdio)
Toolbox supports the
[stdio](https://modelcontextprotocol.io/docs/concepts/transports#standard-input%2Foutput-stdio)
transport protocol. Users that wish to use stdio will have to include the
`--stdio` flag when running Toolbox.
```bash
./toolbox --stdio
```
When running with stdio, Toolbox will listen via stdio instead of acting as a
remote HTTP server. Logs will be set to the `warn` level by default. `debug` and `info` logs are not
supported with stdio.
### Connecting via HTTP
Toolbox supports the HTTP transport protocol with and without SSE.
@@ -61,8 +75,27 @@ If you would like to connect to a specific toolset, connect via `http://127.0.0.
### Using the MCP Inspector with Toolbox
Use MCP [Inspector](https://github.com/modelcontextprotocol/inspector) for testing and debugging Toolbox server.
Use MCP [Inspector](https://github.com/modelcontextprotocol/inspector) for
testing and debugging Toolbox server.
{{< tabpane text=true >}}
{{% tab header="STDIO" lang="en" %}}
1. Run Inspector with Toolbox as a subprocess:
```bash
npx @modelcontextprotocol/inspector ./toolbox --stdio
```
1. For `Transport Type` dropdown menu, select `STDIO`.
1. In `Command`, make sure that it is set to :`./toolbox` (or the correct path to where the Toolbox binary is installed).
1. In `Arguments`, make sure that it's filled with `--stdio`.
1. Click the `Connect` button. It might take awhile to spin up Toolbox. Voila!
You should be able to inspect your toolbox tools!
{{% /tab %}}
{{% tab header="HTTP with SSE" lang="en" %}}
1. [Run Toolbox](../getting-started/introduction/_index.md#running-the-server).
1. In a separate terminal, run Inspector directly through `npx`:
@@ -78,12 +111,13 @@ Use MCP [Inspector](https://github.com/modelcontextprotocol/inspector) for testi
1. Click the `Connect` button. Voila! You should be able to inspect your toolbox
tools!
{{% /tab %}} {{< /tabpane >}}
### Tested Clients
| Client | SSE Works | MCP Config Docs |
|--------|--------|--------|
| Claude Desktop | | Claude Desktop only supports STDIO -- use [`mcp-remote`](https://www.npmjs.com/package/mcp-remote) to proxy. |
| Claude Desktop | | https://modelcontextprotocol.io/quickstart/user#1-download-claude-for-desktop |
| MCP Inspector | ✅ | https://github.com/modelcontextprotocol/inspector |
| Cursor | ✅ | https://docs.cursor.com/context/model-context-protocol |
| Windsurf | ✅ | https://docs.windsurf.com/windsurf/mcp |

View File

@@ -82,6 +82,8 @@ type ServerConfig struct {
TelemetryOTLP string
// TelemetryServiceName defines the value of service.name resource attribute.
TelemetryServiceName string
// Stdio indicates if Toolbox is listening via MCP stdio.
Stdio bool
}
type logFormat string

View File

@@ -15,7 +15,9 @@
package server
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@@ -27,6 +29,7 @@ import (
"github.com/go-chi/render"
"github.com/google/uuid"
"github.com/googleapis/genai-toolbox/internal/server/mcp"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/metric"
@@ -65,6 +68,101 @@ func (m *sseManager) remove(id string) {
m.mu.Unlock()
}
type stdioSession struct {
server *Server
reader *bufio.Reader
writer io.Writer
}
func NewStdioSession(s *Server, stdin io.Reader, stdout io.Writer) *stdioSession {
stdioSession := &stdioSession{
server: s,
reader: bufio.NewReader(stdin),
writer: stdout,
}
return stdioSession
}
func (s *stdioSession) Start(ctx context.Context) error {
return s.readInputStream(ctx)
}
// readInputStream reads requests/notifications from MCP clients through stdin
func (s *stdioSession) readInputStream(ctx context.Context) error {
for {
if err := ctx.Err(); err != nil {
return err
}
line, err := s.readLine(ctx)
if err != nil {
if err == io.EOF {
return nil
}
return err
}
res, err := processMcpMessage(ctx, []byte(line), s.server, "")
if err != nil {
// errors during the processing of message will generate a valid MCP Error response.
// server can continue to run.
s.server.logger.ErrorContext(ctx, err.Error())
}
if err = s.write(ctx, res); err != nil {
return err
}
}
}
// readLine process each line within the input stream.
func (s *stdioSession) readLine(ctx context.Context) (string, error) {
readChan := make(chan string, 1)
errChan := make(chan error, 1)
done := make(chan struct{})
defer close(done)
defer close(readChan)
defer close(errChan)
go func() {
select {
case <-done:
return
default:
line, err := s.reader.ReadString('\n')
if err != nil {
select {
case errChan <- err:
case <-done:
}
return
}
select {
case readChan <- line:
case <-done:
}
return
}
}()
select {
// if context is cancelled, return an empty string
case <-ctx.Done():
return "", ctx.Err()
// return error if error is found
case err := <-errChan:
return "", err
// return line if successful
case line := <-readChan:
return line, nil
}
}
// write writes to stdout with response to client
func (s *stdioSession) write(ctx context.Context, response any) error {
res, _ := json.Marshal(response)
_, err := fmt.Fprintf(s.writer, "%s\n", res)
return err
}
// mcpRouter creates a router that represents the routes under /mcp
func mcpRouter(s *Server) (chi.Router, error) {
r := chi.NewRouter()
@@ -74,11 +172,11 @@ func mcpRouter(s *Server) (chi.Router, error) {
r.Use(render.SetContentType(render.ContentTypeJSON))
r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) })
r.Post("/", func(w http.ResponseWriter, r *http.Request) { mcpHandler(s, w, r) })
r.Post("/", func(w http.ResponseWriter, r *http.Request) { httpHandler(s, w, r) })
r.Route("/{toolsetName}", func(r chi.Router) {
r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) })
r.Post("/", func(w http.ResponseWriter, r *http.Request) { mcpHandler(s, w, r) })
r.Post("/", func(w http.ResponseWriter, r *http.Request) { httpHandler(s, w, r) })
})
return r, nil
@@ -172,16 +270,19 @@ func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
}
}
// mcpHandler handles all mcp messages.
func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
// httpHandler handles all mcp messages.
func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp")
r = r.WithContext(ctx)
ctx = util.WithLogger(r.Context(), s.logger)
toolsetName := chi.URLParam(r, "toolsetName")
s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName))
span.SetAttributes(attribute.String("toolset_name", toolsetName))
var id, toolName, method string
// retrieve sse session id, if applicable
sessionId := r.URL.Query().Get("sessionId")
var err error
defer func() {
if err != nil {
@@ -196,10 +297,7 @@ func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
s.instrumentation.McpPost.Add(
r.Context(),
1,
metric.WithAttributes(attribute.String("toolbox.sse.sessionId", id)),
metric.WithAttributes(attribute.String("toolbox.tool.name", toolName)),
metric.WithAttributes(attribute.String("toolbox.toolset.name", toolsetName)),
metric.WithAttributes(attribute.String("toolbox.method", method)),
metric.WithAttributes(attribute.String("toolbox.sse.sessionId", sessionId)),
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
)
}()
@@ -207,166 +305,25 @@ func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
// Read and returns a body from io.Reader
body, err := io.ReadAll(r.Body)
if err != nil {
// Generate a new uuid if unable to decode
id = uuid.New().String()
s.logger.DebugContext(ctx, err.Error())
render.JSON(w, r, newJSONRPCError(id, mcp.PARSE_ERROR, err.Error(), nil))
}
// Generic baseMessage could either be a JSONRPCNotification or JSONRPCRequest
var baseMessage struct {
Jsonrpc string `json:"jsonrpc"`
Method string `json:"method"`
Id mcp.RequestId `json:"id,omitempty"`
}
if err = decodeJSON(bytes.NewBuffer(body), &baseMessage); err != nil {
// Generate a new uuid if unable to decode
id := uuid.New().String()
s.logger.DebugContext(ctx, err.Error())
render.JSON(w, r, newJSONRPCError(id, mcp.PARSE_ERROR, err.Error(), nil))
return
}
// Check if method is present
if baseMessage.Method == "" {
err = fmt.Errorf("method not found")
s.logger.DebugContext(ctx, err.Error())
render.JSON(w, r, newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil))
return
}
// Check for JSON-RPC 2.0
if baseMessage.Jsonrpc != mcp.JSONRPC_VERSION {
err = fmt.Errorf("invalid json-rpc version")
s.logger.DebugContext(ctx, err.Error())
render.JSON(w, r, newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil))
return
}
// Check if message is a notification
if baseMessage.Id == nil {
id = ""
var notification mcp.JSONRPCNotification
if err = json.Unmarshal(body, &notification); err != nil {
err = fmt.Errorf("invalid notification request: %w", err)
s.logger.DebugContext(ctx, err.Error())
render.JSON(w, r, newJSONRPCError(baseMessage.Id, mcp.PARSE_ERROR, err.Error(), nil))
}
res, err := processMcpMessage(ctx, body, s, toolsetName)
// notifications will return empty string
if res == nil {
// Notifications do not expect a response
// Toolbox doesn't do anything with notifications yet
w.WriteHeader(http.StatusAccepted)
return
}
id = fmt.Sprintf("%s", baseMessage.Id)
method = baseMessage.Method
s.logger.DebugContext(ctx, fmt.Sprintf("method is: %s", method))
var res mcp.JSONRPCMessage
switch baseMessage.Method {
case "initialize":
var req mcp.InitializeRequest
if err = json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp initialize request: %w", err)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
break
}
result := mcp.Initialize(s.version)
res = mcp.JSONRPCResponse{
Jsonrpc: mcp.JSONRPC_VERSION,
Id: baseMessage.Id,
Result: result,
}
case "tools/list":
var req mcp.ListToolsRequest
if err = json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp tools list request: %w", err)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
break
}
toolset, ok := s.toolsets[toolsetName]
if !ok {
err = fmt.Errorf("toolset does not exist")
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
break
}
result := mcp.ToolsList(toolset)
res = mcp.JSONRPCResponse{
Jsonrpc: mcp.JSONRPC_VERSION,
Id: baseMessage.Id,
Result: result,
}
case "tools/call":
var req mcp.CallToolRequest
if err = json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp tools call request: %w", err)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
break
}
toolName = req.Params.Name
toolArgument := req.Params.Arguments
s.logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
tool, ok := s.tools[toolName]
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil)
break
}
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
aMarshal, err := json.Marshal(toolArgument)
if err != nil {
err = fmt.Errorf("unable to marshal tools argument: %w", err)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil)
break
}
var data map[string]any
if err = decodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
err = fmt.Errorf("unable to decode tools argument: %w", err)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil)
break
}
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
// Since MCP doesn't support auth, an empty map will be use every time.
claimsFromAuth := make(map[string]map[string]any)
params, err := tool.ParseParams(data, claimsFromAuth)
if err != nil {
err = fmt.Errorf("provided parameters were invalid: %w", err)
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil)
break
}
s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
if !tool.Authorized([]string{}) {
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool")
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
break
}
result := mcp.ToolCall(ctx, tool, params)
res = mcp.JSONRPCResponse{
Jsonrpc: mcp.JSONRPC_VERSION,
Id: baseMessage.Id,
Result: result,
}
default:
err = fmt.Errorf("invalid method %s", baseMessage.Method)
if err != nil {
s.logger.DebugContext(ctx, err.Error())
res = newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil)
}
// retrieve sse session
sessionId := r.URL.Query().Get("sessionId")
session, ok := s.sseManager.get(sessionId)
if !ok {
s.logger.DebugContext(ctx, "sse session not available")
@@ -387,6 +344,133 @@ func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
render.JSON(w, r, res)
}
// processMcpMessage process the messages received from clients
func processMcpMessage(ctx context.Context, body []byte, s *Server, toolsetName string) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return newJSONRPCError("", mcp.INTERNAL_ERROR, err.Error(), nil), err
}
// Generic baseMessage could either be a JSONRPCNotification or JSONRPCRequest
var baseMessage struct {
Jsonrpc string `json:"jsonrpc"`
Method string `json:"method"`
Id mcp.RequestId `json:"id,omitempty"`
}
if err = decodeJSON(bytes.NewBuffer(body), &baseMessage); err != nil {
// Generate a new uuid if unable to decode
id := uuid.New().String()
return newJSONRPCError(id, mcp.PARSE_ERROR, err.Error(), nil), err
}
// Check if method is present
if baseMessage.Method == "" {
err = fmt.Errorf("method not found")
return newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil), err
}
logger.DebugContext(ctx, fmt.Sprintf("method is: %s", baseMessage.Method))
// Check for JSON-RPC 2.0
if baseMessage.Jsonrpc != mcp.JSONRPC_VERSION {
err = fmt.Errorf("invalid json-rpc version")
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
}
// Check if message is a notification
if baseMessage.Id == nil {
var notification mcp.JSONRPCNotification
if err = json.Unmarshal(body, &notification); err != nil {
err = fmt.Errorf("invalid notification request: %w", err)
return nil, err
}
return nil, nil
}
switch baseMessage.Method {
case "initialize":
var req mcp.InitializeRequest
if err = json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp initialize request: %w", err)
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
}
result := mcp.Initialize(s.version)
return mcp.JSONRPCResponse{
Jsonrpc: mcp.JSONRPC_VERSION,
Id: baseMessage.Id,
Result: result,
}, nil
case "tools/list":
var req mcp.ListToolsRequest
if err = json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp tools list request: %w", err)
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
}
toolset, ok := s.toolsets[toolsetName]
if !ok {
err = fmt.Errorf("toolset does not exist")
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
}
result := mcp.ToolsList(toolset)
return mcp.JSONRPCResponse{
Jsonrpc: mcp.JSONRPC_VERSION,
Id: baseMessage.Id,
Result: result,
}, nil
case "tools/call":
var req mcp.CallToolRequest
if err = json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp tools call request: %w", err)
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
}
toolName := req.Params.Name
toolArgument := req.Params.Arguments
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
tool, ok := s.tools[toolName]
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
return newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil), err
}
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
aMarshal, err := json.Marshal(toolArgument)
if err != nil {
err = fmt.Errorf("unable to marshal tools argument: %w", err)
return newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil), err
}
var data map[string]any
if err = decodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
err = fmt.Errorf("unable to decode tools argument: %w", err)
return newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil), err
}
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
// Since MCP doesn't support auth, an empty map will be use every time.
claimsFromAuth := make(map[string]map[string]any)
params, err := tool.ParseParams(data, claimsFromAuth)
if err != nil {
err = fmt.Errorf("provided parameters were invalid: %w", err)
return newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil), err
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
if !tool.Authorized([]string{}) {
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool")
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
}
result := mcp.ToolCall(ctx, tool, params)
return mcp.JSONRPCResponse{
Jsonrpc: mcp.JSONRPC_VERSION,
Id: baseMessage.Id,
Result: result,
}, nil
default:
err = fmt.Errorf("invalid method %s", baseMessage.Method)
return newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil), err
}
}
// newJSONRPCError is the response sent back when an error has been encountered in mcp.
func newJSONRPCError(id mcp.RequestId, code int, message string, data any) mcp.JSONRPCError {
return mcp.JSONRPCError{

View File

@@ -15,16 +15,22 @@
package server
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"reflect"
"strings"
"sync"
"testing"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/server/mcp"
"github.com/googleapis/genai-toolbox/internal/telemetry"
)
const jsonrpcVersion = "2.0"
@@ -379,3 +385,78 @@ func runSseRequest(ts *httptest.Server, path string, proto string) (*http.Respon
}
return resp, nil
}
func TestStdioSession(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mockTools := []MockTool{tool1, tool2, tool3}
toolsMap, toolsets := setUpResources(t, mockTools)
pr, pw, err := os.Pipe()
if err != nil {
t.Fatalf("error with Pipe: %s", err)
}
testLogger, err := log.NewStdLogger(pw, os.Stderr, "warn")
if err != nil {
t.Fatalf("unable to initialize logger: %s", err)
}
otelShutdown, err := telemetry.SetupOTel(ctx, fakeVersionString, "", false, "toolbox")
if err != nil {
t.Fatalf("unable to setup otel: %s", err)
}
defer func() {
err := otelShutdown(ctx)
if err != nil {
t.Fatalf("error shutting down OpenTelemetry: %s", err)
}
}()
instrumentation, err := CreateTelemetryInstrumentation(fakeVersionString)
if err != nil {
t.Fatalf("unable to create custom metrics: %s", err)
}
sseManager := &sseManager{
mu: sync.RWMutex{},
sseSessions: make(map[string]*sseSession),
}
server := &Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, sseManager: sseManager, tools: toolsMap, toolsets: toolsets}
in := bufio.NewReader(pr)
stdioSession := NewStdioSession(server, in, pw)
// test stdioSession.readLine()
input := "test readLine function\n"
_, err = fmt.Fprintf(pw, "%s", input)
if err != nil {
t.Fatalf("error writing into pipe w: %s", err)
}
line, err := stdioSession.readLine(ctx)
if err != nil {
t.Fatalf("error with stdioSession.readLine: %s", err)
}
if line != input {
t.Fatalf("unexpected line: got %s, want %s", line, input)
}
// test stdioSession.write()
write := "test write function"
err = stdioSession.write(ctx, write)
if err != nil {
t.Fatalf("error with stdioSession.write: %s", err)
}
read, err := in.ReadString('\n')
if err != nil {
t.Fatalf("error reading: %s", err)
}
want := fmt.Sprintf(`"%s"`, write) + "\n"
if read != want {
t.Fatalf("unexpected read: got %s, want %s", read, want)
}
}

View File

@@ -17,6 +17,7 @@ package server
import (
"context"
"fmt"
"io"
"net"
"net/http"
"strconv"
@@ -265,6 +266,12 @@ func (s *Server) Serve(ctx context.Context) error {
return s.srv.Serve(s.listener)
}
// ServeStdio starts a new stdio session for mcp.
func (s *Server) ServeStdio(ctx context.Context, stdin io.Reader, stdout io.Writer) error {
stdioServer := NewStdioSession(s, stdin, stdout)
return stdioServer.Start(ctx)
}
// Shutdown gracefully shuts down the server without interrupting any active
// connections. It uses http.Server.Shutdown() and has the same functionality.
func (s *Server) Shutdown(ctx context.Context) error {