mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
feat: support MCP stdio transport protocol (#607)
Support MCP [stdio](https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#stdio) transport protocol! To run stdio with Toolbox, user have to use the `--stdio` flag. Example of running MCP Toolbox with MCP Inspector via stdio transport protocol: `npx @modelcontextprotocol/inspector ./toolbox --stdio`. --------- Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Co-authored-by: Averi Kitsch <akitsch@google.com>
This commit is contained in:
38
cmd/root.go
38
cmd/root.go
@@ -71,12 +71,14 @@ type Command struct {
|
||||
cfg server.ServerConfig
|
||||
logger log.Logger
|
||||
tools_file string
|
||||
inStream io.Reader
|
||||
outStream io.Writer
|
||||
errStream io.Writer
|
||||
}
|
||||
|
||||
// NewCommand returns a Command object representing an invocation of the CLI.
|
||||
func NewCommand(opts ...Option) *Command {
|
||||
in := os.Stdin
|
||||
out := os.Stdout
|
||||
err := os.Stderr
|
||||
|
||||
@@ -87,6 +89,7 @@ func NewCommand(opts ...Option) *Command {
|
||||
}
|
||||
cmd := &Command{
|
||||
Command: baseCmd,
|
||||
inStream: in,
|
||||
outStream: out,
|
||||
errStream: err,
|
||||
}
|
||||
@@ -98,7 +101,8 @@ func NewCommand(opts ...Option) *Command {
|
||||
// Set server version
|
||||
cmd.cfg.Version = versionString
|
||||
|
||||
// set baseCmd out and err the same as cmd.
|
||||
// set baseCmd in, out and err the same as cmd.
|
||||
baseCmd.SetIn(cmd.inStream)
|
||||
baseCmd.SetOut(cmd.outStream)
|
||||
baseCmd.SetErr(cmd.errStream)
|
||||
|
||||
@@ -115,6 +119,7 @@ func NewCommand(opts ...Option) *Command {
|
||||
flags.BoolVar(&cmd.cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.")
|
||||
flags.StringVar(&cmd.cfg.TelemetryOTLP, "telemetry-otlp", "", "Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318')")
|
||||
flags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.")
|
||||
flags.BoolVar(&cmd.cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.")
|
||||
|
||||
// wrap RunE command so that we have access to original Command object
|
||||
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }
|
||||
@@ -163,7 +168,25 @@ func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
|
||||
return toolsFile, nil
|
||||
}
|
||||
|
||||
// updateLogLevel checks if Toolbox have to update the existing log level set by users.
|
||||
// stdio doesn't support "debug" and "info" logs.
|
||||
func updateLogLevel(stdio bool, logLevel string) bool {
|
||||
if stdio {
|
||||
switch strings.ToUpper(logLevel) {
|
||||
case log.Debug, log.Info:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func run(cmd *Command) error {
|
||||
if updateLogLevel(cmd.cfg.Stdio, cmd.cfg.LogLevel.String()) {
|
||||
cmd.cfg.LogLevel = server.StringLevel(log.Warn)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
@@ -262,9 +285,16 @@ func run(cmd *Command) error {
|
||||
srvErr := make(chan error)
|
||||
go func() {
|
||||
defer close(srvErr)
|
||||
err = s.Serve(ctx)
|
||||
if err != nil {
|
||||
srvErr <- err
|
||||
if cmd.cfg.Stdio {
|
||||
err = s.ServeStdio(ctx, cmd.inStream, cmd.outStream)
|
||||
if err != nil {
|
||||
srvErr <- err
|
||||
}
|
||||
} else {
|
||||
err = s.Serve(ctx)
|
||||
if err != nil {
|
||||
srvErr <- err
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -163,6 +163,13 @@ func TestServerConfigFlags(t *testing.T) {
|
||||
TelemetryServiceName: "toolbox-custom",
|
||||
}),
|
||||
},
|
||||
{
|
||||
desc: "stdio",
|
||||
args: []string{"--stdio"},
|
||||
want: withDefaults(server.ServerConfig{
|
||||
Stdio: true,
|
||||
}),
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
@@ -858,3 +865,51 @@ func TestEnvVarReplacement(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestUpdateLogLevel(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
stdio bool
|
||||
logLevel string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
desc: "no stdio",
|
||||
stdio: false,
|
||||
logLevel: "info",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
desc: "stdio with info log",
|
||||
stdio: true,
|
||||
logLevel: "info",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
desc: "stdio with debug log",
|
||||
stdio: true,
|
||||
logLevel: "debug",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
desc: "stdio with warn log",
|
||||
stdio: true,
|
||||
logLevel: "warn",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
desc: "stdio with error log",
|
||||
stdio: true,
|
||||
logLevel: "error",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := updateLogLevel(tc.stdio, tc.logLevel)
|
||||
if got != tc.want {
|
||||
t.Fatalf("incorrect indication to update log level: got %t, want %t", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,6 +36,20 @@ MCP is only compatible with Toolbox version 0.3.0 and above.
|
||||
|
||||
1. [Set up](../getting-started/configure.md) your `tools.yaml` file.
|
||||
|
||||
### Connecting via Standard Input/Output (stdio)
|
||||
Toolbox supports the
|
||||
[stdio](https://modelcontextprotocol.io/docs/concepts/transports#standard-input%2Foutput-stdio)
|
||||
transport protocol. Users that wish to use stdio will have to include the
|
||||
`--stdio` flag when running Toolbox.
|
||||
|
||||
```bash
|
||||
./toolbox --stdio
|
||||
```
|
||||
|
||||
When running with stdio, Toolbox will listen via stdio instead of acting as a
|
||||
remote HTTP server. Logs will be set to the `warn` level by default. `debug` and `info` logs are not
|
||||
supported with stdio.
|
||||
|
||||
### Connecting via HTTP
|
||||
Toolbox supports the HTTP transport protocol with and without SSE.
|
||||
|
||||
@@ -61,8 +75,27 @@ If you would like to connect to a specific toolset, connect via `http://127.0.0.
|
||||
|
||||
### Using the MCP Inspector with Toolbox
|
||||
|
||||
Use MCP [Inspector](https://github.com/modelcontextprotocol/inspector) for testing and debugging Toolbox server.
|
||||
Use MCP [Inspector](https://github.com/modelcontextprotocol/inspector) for
|
||||
testing and debugging Toolbox server.
|
||||
|
||||
{{< tabpane text=true >}}
|
||||
{{% tab header="STDIO" lang="en" %}}
|
||||
1. Run Inspector with Toolbox as a subprocess:
|
||||
|
||||
```bash
|
||||
npx @modelcontextprotocol/inspector ./toolbox --stdio
|
||||
```
|
||||
|
||||
1. For `Transport Type` dropdown menu, select `STDIO`.
|
||||
|
||||
1. In `Command`, make sure that it is set to :`./toolbox` (or the correct path to where the Toolbox binary is installed).
|
||||
|
||||
1. In `Arguments`, make sure that it's filled with `--stdio`.
|
||||
|
||||
1. Click the `Connect` button. It might take awhile to spin up Toolbox. Voila!
|
||||
You should be able to inspect your toolbox tools!
|
||||
{{% /tab %}}
|
||||
{{% tab header="HTTP with SSE" lang="en" %}}
|
||||
1. [Run Toolbox](../getting-started/introduction/_index.md#running-the-server).
|
||||
|
||||
1. In a separate terminal, run Inspector directly through `npx`:
|
||||
@@ -78,12 +111,13 @@ Use MCP [Inspector](https://github.com/modelcontextprotocol/inspector) for testi
|
||||
|
||||
1. Click the `Connect` button. Voila! You should be able to inspect your toolbox
|
||||
tools!
|
||||
{{% /tab %}} {{< /tabpane >}}
|
||||
|
||||
### Tested Clients
|
||||
|
||||
| Client | SSE Works | MCP Config Docs |
|
||||
|--------|--------|--------|
|
||||
| Claude Desktop | ❗ | Claude Desktop only supports STDIO -- use [`mcp-remote`](https://www.npmjs.com/package/mcp-remote) to proxy. |
|
||||
| Claude Desktop | ✅ | https://modelcontextprotocol.io/quickstart/user#1-download-claude-for-desktop |
|
||||
| MCP Inspector | ✅ | https://github.com/modelcontextprotocol/inspector |
|
||||
| Cursor | ✅ | https://docs.cursor.com/context/model-context-protocol |
|
||||
| Windsurf | ✅ | https://docs.windsurf.com/windsurf/mcp |
|
||||
|
||||
@@ -82,6 +82,8 @@ type ServerConfig struct {
|
||||
TelemetryOTLP string
|
||||
// TelemetryServiceName defines the value of service.name resource attribute.
|
||||
TelemetryServiceName string
|
||||
// Stdio indicates if Toolbox is listening via MCP stdio.
|
||||
Stdio bool
|
||||
}
|
||||
|
||||
type logFormat string
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -27,6 +29,7 @@ import (
|
||||
"github.com/go-chi/render"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
@@ -65,6 +68,101 @@ func (m *sseManager) remove(id string) {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
type stdioSession struct {
|
||||
server *Server
|
||||
reader *bufio.Reader
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
func NewStdioSession(s *Server, stdin io.Reader, stdout io.Writer) *stdioSession {
|
||||
stdioSession := &stdioSession{
|
||||
server: s,
|
||||
reader: bufio.NewReader(stdin),
|
||||
writer: stdout,
|
||||
}
|
||||
return stdioSession
|
||||
}
|
||||
|
||||
func (s *stdioSession) Start(ctx context.Context) error {
|
||||
return s.readInputStream(ctx)
|
||||
}
|
||||
|
||||
// readInputStream reads requests/notifications from MCP clients through stdin
|
||||
func (s *stdioSession) readInputStream(ctx context.Context) error {
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
line, err := s.readLine(ctx)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
res, err := processMcpMessage(ctx, []byte(line), s.server, "")
|
||||
if err != nil {
|
||||
// errors during the processing of message will generate a valid MCP Error response.
|
||||
// server can continue to run.
|
||||
s.server.logger.ErrorContext(ctx, err.Error())
|
||||
}
|
||||
if err = s.write(ctx, res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readLine process each line within the input stream.
|
||||
func (s *stdioSession) readLine(ctx context.Context) (string, error) {
|
||||
readChan := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
defer close(readChan)
|
||||
defer close(errChan)
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
line, err := s.reader.ReadString('\n')
|
||||
if err != nil {
|
||||
select {
|
||||
case errChan <- err:
|
||||
case <-done:
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case readChan <- line:
|
||||
case <-done:
|
||||
}
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
// if context is cancelled, return an empty string
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
// return error if error is found
|
||||
case err := <-errChan:
|
||||
return "", err
|
||||
// return line if successful
|
||||
case line := <-readChan:
|
||||
return line, nil
|
||||
}
|
||||
}
|
||||
|
||||
// write writes to stdout with response to client
|
||||
func (s *stdioSession) write(ctx context.Context, response any) error {
|
||||
res, _ := json.Marshal(response)
|
||||
|
||||
_, err := fmt.Fprintf(s.writer, "%s\n", res)
|
||||
return err
|
||||
}
|
||||
|
||||
// mcpRouter creates a router that represents the routes under /mcp
|
||||
func mcpRouter(s *Server) (chi.Router, error) {
|
||||
r := chi.NewRouter()
|
||||
@@ -74,11 +172,11 @@ func mcpRouter(s *Server) (chi.Router, error) {
|
||||
r.Use(render.SetContentType(render.ContentTypeJSON))
|
||||
|
||||
r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) })
|
||||
r.Post("/", func(w http.ResponseWriter, r *http.Request) { mcpHandler(s, w, r) })
|
||||
r.Post("/", func(w http.ResponseWriter, r *http.Request) { httpHandler(s, w, r) })
|
||||
|
||||
r.Route("/{toolsetName}", func(r chi.Router) {
|
||||
r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) })
|
||||
r.Post("/", func(w http.ResponseWriter, r *http.Request) { mcpHandler(s, w, r) })
|
||||
r.Post("/", func(w http.ResponseWriter, r *http.Request) { httpHandler(s, w, r) })
|
||||
})
|
||||
|
||||
return r, nil
|
||||
@@ -172,16 +270,19 @@ func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// mcpHandler handles all mcp messages.
|
||||
func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
// httpHandler handles all mcp messages.
|
||||
func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp")
|
||||
r = r.WithContext(ctx)
|
||||
ctx = util.WithLogger(r.Context(), s.logger)
|
||||
|
||||
toolsetName := chi.URLParam(r, "toolsetName")
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName))
|
||||
span.SetAttributes(attribute.String("toolset_name", toolsetName))
|
||||
|
||||
var id, toolName, method string
|
||||
// retrieve sse session id, if applicable
|
||||
sessionId := r.URL.Query().Get("sessionId")
|
||||
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
@@ -196,10 +297,7 @@ func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
s.instrumentation.McpPost.Add(
|
||||
r.Context(),
|
||||
1,
|
||||
metric.WithAttributes(attribute.String("toolbox.sse.sessionId", id)),
|
||||
metric.WithAttributes(attribute.String("toolbox.tool.name", toolName)),
|
||||
metric.WithAttributes(attribute.String("toolbox.toolset.name", toolsetName)),
|
||||
metric.WithAttributes(attribute.String("toolbox.method", method)),
|
||||
metric.WithAttributes(attribute.String("toolbox.sse.sessionId", sessionId)),
|
||||
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
|
||||
)
|
||||
}()
|
||||
@@ -207,166 +305,25 @@ func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
// Read and returns a body from io.Reader
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
// Generate a new uuid if unable to decode
|
||||
id = uuid.New().String()
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
render.JSON(w, r, newJSONRPCError(id, mcp.PARSE_ERROR, err.Error(), nil))
|
||||
}
|
||||
|
||||
// Generic baseMessage could either be a JSONRPCNotification or JSONRPCRequest
|
||||
var baseMessage struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
Id mcp.RequestId `json:"id,omitempty"`
|
||||
}
|
||||
if err = decodeJSON(bytes.NewBuffer(body), &baseMessage); err != nil {
|
||||
// Generate a new uuid if unable to decode
|
||||
id := uuid.New().String()
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
render.JSON(w, r, newJSONRPCError(id, mcp.PARSE_ERROR, err.Error(), nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Check if method is present
|
||||
if baseMessage.Method == "" {
|
||||
err = fmt.Errorf("method not found")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
render.JSON(w, r, newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Check for JSON-RPC 2.0
|
||||
if baseMessage.Jsonrpc != mcp.JSONRPC_VERSION {
|
||||
err = fmt.Errorf("invalid json-rpc version")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
render.JSON(w, r, newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Check if message is a notification
|
||||
if baseMessage.Id == nil {
|
||||
id = ""
|
||||
var notification mcp.JSONRPCNotification
|
||||
if err = json.Unmarshal(body, ¬ification); err != nil {
|
||||
err = fmt.Errorf("invalid notification request: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
render.JSON(w, r, newJSONRPCError(baseMessage.Id, mcp.PARSE_ERROR, err.Error(), nil))
|
||||
}
|
||||
res, err := processMcpMessage(ctx, body, s, toolsetName)
|
||||
// notifications will return empty string
|
||||
if res == nil {
|
||||
// Notifications do not expect a response
|
||||
// Toolbox doesn't do anything with notifications yet
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
id = fmt.Sprintf("%s", baseMessage.Id)
|
||||
method = baseMessage.Method
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("method is: %s", method))
|
||||
|
||||
var res mcp.JSONRPCMessage
|
||||
switch baseMessage.Method {
|
||||
case "initialize":
|
||||
var req mcp.InitializeRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp initialize request: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
result := mcp.Initialize(s.version)
|
||||
res = mcp.JSONRPCResponse{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: baseMessage.Id,
|
||||
Result: result,
|
||||
}
|
||||
case "tools/list":
|
||||
var req mcp.ListToolsRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools list request: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
toolset, ok := s.toolsets[toolsetName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("toolset does not exist")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
result := mcp.ToolsList(toolset)
|
||||
res = mcp.JSONRPCResponse{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: baseMessage.Id,
|
||||
Result: result,
|
||||
}
|
||||
case "tools/call":
|
||||
var req mcp.CallToolRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools call request: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
toolName = req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||
tool, ok := s.tools[toolName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
|
||||
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
|
||||
aMarshal, err := json.Marshal(toolArgument)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to marshal tools argument: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
var data map[string]any
|
||||
if err = decodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
|
||||
err = fmt.Errorf("unable to decode tools argument: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
|
||||
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
|
||||
// Since MCP doesn't support auth, an empty map will be use every time.
|
||||
claimsFromAuth := make(map[string]map[string]any)
|
||||
|
||||
params, err := tool.ParseParams(data, claimsFromAuth)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
if !tool.Authorized([]string{}) {
|
||||
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
|
||||
result := mcp.ToolCall(ctx, tool, params)
|
||||
res = mcp.JSONRPCResponse{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: baseMessage.Id,
|
||||
Result: result,
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("invalid method %s", baseMessage.Method)
|
||||
if err != nil {
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil)
|
||||
}
|
||||
|
||||
// retrieve sse session
|
||||
sessionId := r.URL.Query().Get("sessionId")
|
||||
session, ok := s.sseManager.get(sessionId)
|
||||
if !ok {
|
||||
s.logger.DebugContext(ctx, "sse session not available")
|
||||
@@ -387,6 +344,133 @@ func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
render.JSON(w, r, res)
|
||||
}
|
||||
|
||||
// processMcpMessage process the messages received from clients
|
||||
func processMcpMessage(ctx context.Context, body []byte, s *Server, toolsetName string) (any, error) {
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return newJSONRPCError("", mcp.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// Generic baseMessage could either be a JSONRPCNotification or JSONRPCRequest
|
||||
var baseMessage struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
Id mcp.RequestId `json:"id,omitempty"`
|
||||
}
|
||||
if err = decodeJSON(bytes.NewBuffer(body), &baseMessage); err != nil {
|
||||
// Generate a new uuid if unable to decode
|
||||
id := uuid.New().String()
|
||||
return newJSONRPCError(id, mcp.PARSE_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// Check if method is present
|
||||
if baseMessage.Method == "" {
|
||||
err = fmt.Errorf("method not found")
|
||||
return newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("method is: %s", baseMessage.Method))
|
||||
|
||||
// Check for JSON-RPC 2.0
|
||||
if baseMessage.Jsonrpc != mcp.JSONRPC_VERSION {
|
||||
err = fmt.Errorf("invalid json-rpc version")
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// Check if message is a notification
|
||||
if baseMessage.Id == nil {
|
||||
var notification mcp.JSONRPCNotification
|
||||
if err = json.Unmarshal(body, ¬ification); err != nil {
|
||||
err = fmt.Errorf("invalid notification request: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch baseMessage.Method {
|
||||
case "initialize":
|
||||
var req mcp.InitializeRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp initialize request: %w", err)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
result := mcp.Initialize(s.version)
|
||||
return mcp.JSONRPCResponse{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: baseMessage.Id,
|
||||
Result: result,
|
||||
}, nil
|
||||
case "tools/list":
|
||||
var req mcp.ListToolsRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools list request: %w", err)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
toolset, ok := s.toolsets[toolsetName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("toolset does not exist")
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
result := mcp.ToolsList(toolset)
|
||||
return mcp.JSONRPCResponse{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: baseMessage.Id,
|
||||
Result: result,
|
||||
}, nil
|
||||
case "tools/call":
|
||||
var req mcp.CallToolRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools call request: %w", err)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
toolName := req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||
tool, ok := s.tools[toolName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
|
||||
aMarshal, err := json.Marshal(toolArgument)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to marshal tools argument: %w", err)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
var data map[string]any
|
||||
if err = decodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
|
||||
err = fmt.Errorf("unable to decode tools argument: %w", err)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
|
||||
// Since MCP doesn't support auth, an empty map will be use every time.
|
||||
claimsFromAuth := make(map[string]map[string]any)
|
||||
|
||||
params, err := tool.ParseParams(data, claimsFromAuth)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
if !tool.Authorized([]string{}) {
|
||||
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool")
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
result := mcp.ToolCall(ctx, tool, params)
|
||||
return mcp.JSONRPCResponse{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: baseMessage.Id,
|
||||
Result: result,
|
||||
}, nil
|
||||
default:
|
||||
err = fmt.Errorf("invalid method %s", baseMessage.Method)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
|
||||
// newJSONRPCError is the response sent back when an error has been encountered in mcp.
|
||||
func newJSONRPCError(id mcp.RequestId, code int, message string, data any) mcp.JSONRPCError {
|
||||
return mcp.JSONRPCError{
|
||||
|
||||
@@ -15,16 +15,22 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp"
|
||||
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||
)
|
||||
|
||||
const jsonrpcVersion = "2.0"
|
||||
@@ -379,3 +385,78 @@ func runSseRequest(ts *httptest.Server, path string, proto string) (*http.Respon
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func TestStdioSession(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
mockTools := []MockTool{tool1, tool2, tool3}
|
||||
toolsMap, toolsets := setUpResources(t, mockTools)
|
||||
|
||||
pr, pw, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("error with Pipe: %s", err)
|
||||
}
|
||||
|
||||
testLogger, err := log.NewStdLogger(pw, os.Stderr, "warn")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to initialize logger: %s", err)
|
||||
}
|
||||
|
||||
otelShutdown, err := telemetry.SetupOTel(ctx, fakeVersionString, "", false, "toolbox")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to setup otel: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := otelShutdown(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("error shutting down OpenTelemetry: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instrumentation, err := CreateTelemetryInstrumentation(fakeVersionString)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create custom metrics: %s", err)
|
||||
}
|
||||
|
||||
sseManager := &sseManager{
|
||||
mu: sync.RWMutex{},
|
||||
sseSessions: make(map[string]*sseSession),
|
||||
}
|
||||
|
||||
server := &Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, sseManager: sseManager, tools: toolsMap, toolsets: toolsets}
|
||||
|
||||
in := bufio.NewReader(pr)
|
||||
stdioSession := NewStdioSession(server, in, pw)
|
||||
|
||||
// test stdioSession.readLine()
|
||||
input := "test readLine function\n"
|
||||
_, err = fmt.Fprintf(pw, "%s", input)
|
||||
if err != nil {
|
||||
t.Fatalf("error writing into pipe w: %s", err)
|
||||
}
|
||||
|
||||
line, err := stdioSession.readLine(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("error with stdioSession.readLine: %s", err)
|
||||
}
|
||||
if line != input {
|
||||
t.Fatalf("unexpected line: got %s, want %s", line, input)
|
||||
}
|
||||
|
||||
// test stdioSession.write()
|
||||
write := "test write function"
|
||||
err = stdioSession.write(ctx, write)
|
||||
if err != nil {
|
||||
t.Fatalf("error with stdioSession.write: %s", err)
|
||||
}
|
||||
|
||||
read, err := in.ReadString('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("error reading: %s", err)
|
||||
}
|
||||
want := fmt.Sprintf(`"%s"`, write) + "\n"
|
||||
if read != want {
|
||||
t.Fatalf("unexpected read: got %s, want %s", read, want)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -265,6 +266,12 @@ func (s *Server) Serve(ctx context.Context) error {
|
||||
return s.srv.Serve(s.listener)
|
||||
}
|
||||
|
||||
// ServeStdio starts a new stdio session for mcp.
|
||||
func (s *Server) ServeStdio(ctx context.Context, stdin io.Reader, stdout io.Writer) error {
|
||||
stdioServer := NewStdioSession(s, stdin, stdout)
|
||||
return stdioServer.Start(ctx)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server without interrupting any active
|
||||
// connections. It uses http.Server.Shutdown() and has the same functionality.
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
|
||||
Reference in New Issue
Block a user