From e307857085ac4c8c2ee8292c914daa5534ba74bf Mon Sep 17 00:00:00 2001 From: Yuan <45984206+Yuan325@users.noreply.github.com> Date: Wed, 16 Apr 2025 14:11:53 -0700 Subject: [PATCH] feat: add toolset feature to mcp (#425) Update MCP server to support toolset. User can now connect to specific toolset's sse via `/mcp/{toolset_name}/sse` url, or POST to `/mcp/{toolset_name}`. If toolset_name is not provided, it will list all tools by default. Fixes #403 --- internal/server/mcp.go | 29 ++++++-- internal/server/mcp_test.go | 129 +++++++++++++++++++++++++++--------- 2 files changed, 123 insertions(+), 35 deletions(-) diff --git a/internal/server/mcp.go b/internal/server/mcp.go index ab22760a8a..7fda6b076c 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -76,6 +76,11 @@ func mcpRouter(s *Server) (chi.Router, error) { r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) }) r.Post("/", func(w http.ResponseWriter, r *http.Request) { mcpHandler(s, w, r) }) + r.Route("/{toolsetName}", func(r chi.Router) { + r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) }) + r.Post("/", func(w http.ResponseWriter, r *http.Request) { mcpHandler(s, w, r) }) + }) + return r, nil } @@ -85,7 +90,10 @@ func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) { r = r.WithContext(ctx) sessionId := uuid.New().String() + toolsetName := chi.URLParam(r, "toolsetName") + s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName)) span.SetAttributes(attribute.String("session_id", sessionId)) + span.SetAttributes(attribute.String("toolset_name", toolsetName)) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -105,6 +113,7 @@ func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) { s.instrumentation.McpSse.Add( r.Context(), 1, + metric.WithAttributes(attribute.String("toolbox.toolset.name", toolsetName)), metric.WithAttributes(attribute.String("toolbox.sse.sessionId", sessionId)), metric.WithAttributes(attribute.String("toolbox.operation.status", status)), ) @@ -128,15 +137,20 @@ func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) { // https scheme formatting if (forwarded) request is a TLS request proto := r.Header.Get("X-Forwarded-Proto") - if (proto == "") { - if r.TLS == nil { + if proto == "" { + if r.TLS == nil { proto = "http" } else { proto = "https" } } + // send initial endpoint event - messageEndpoint := fmt.Sprintf("%s://%s/mcp?sessionId=%s", proto, r.Host, sessionId) + toolsetURL := "" + if toolsetName != "" { + toolsetURL = fmt.Sprintf("/%s", toolsetName) + } + messageEndpoint := fmt.Sprintf("%s://%s/mcp%s?sessionId=%s", proto, r.Host, toolsetURL, sessionId) s.logger.DebugContext(ctx, fmt.Sprintf("sending endpoint event: %s", messageEndpoint)) fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", messageEndpoint) flusher.Flush() @@ -163,6 +177,10 @@ func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) { ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp") r = r.WithContext(ctx) + toolsetName := chi.URLParam(r, "toolsetName") + s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName)) + span.SetAttributes(attribute.String("toolset_name", toolsetName)) + var id, toolName, method string var err error defer func() { @@ -179,7 +197,8 @@ func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) { r.Context(), 1, metric.WithAttributes(attribute.String("toolbox.sse.sessionId", id)), - metric.WithAttributes(attribute.String("toolbox.name", toolName)), + metric.WithAttributes(attribute.String("toolbox.tool.name", toolName)), + metric.WithAttributes(attribute.String("toolbox.toolset.name", toolsetName)), metric.WithAttributes(attribute.String("toolbox.method", method)), metric.WithAttributes(attribute.String("toolbox.operation.status", status)), ) @@ -266,7 +285,7 @@ func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) { res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil) break } - toolset, ok := s.toolsets[""] + toolset, ok := s.toolsets[toolsetName] if !ok { err = fmt.Errorf("toolset does not exist") s.logger.DebugContext(ctx, err.Error()) diff --git a/internal/server/mcp_test.go b/internal/server/mcp_test.go index b5f422c65e..a054231d01 100644 --- a/internal/server/mcp_test.go +++ b/internal/server/mcp_test.go @@ -17,6 +17,7 @@ package server import ( "bytes" "encoding/json" + "fmt" "net/http" "reflect" "strings" @@ -64,12 +65,14 @@ func TestMcpEndpoint(t *testing.T) { testCases := []struct { name string + url string isErr bool body mcp.JSONRPCRequest want map[string]any }{ { name: "initialize", + url: "/", body: mcp.JSONRPCRequest{ Jsonrpc: jsonrpcVersion, Id: "mcp-initialize", @@ -91,6 +94,7 @@ func TestMcpEndpoint(t *testing.T) { }, { name: "basic notification", + url: "/", body: mcp.JSONRPCRequest{ Jsonrpc: jsonrpcVersion, Request: mcp.Request{ @@ -100,6 +104,7 @@ func TestMcpEndpoint(t *testing.T) { }, { name: "tools/list", + url: "/", body: mcp.JSONRPCRequest{ Jsonrpc: jsonrpcVersion, Id: "tools-list", @@ -129,8 +134,52 @@ func TestMcpEndpoint(t *testing.T) { }, }, }, + { + name: "tools/list on tool1_only", + url: "/tool1_only", + body: mcp.JSONRPCRequest{ + Jsonrpc: jsonrpcVersion, + Id: "tools-list-tool1", + Request: mcp.Request{ + Method: "tools/list", + }, + }, + want: map[string]any{ + "jsonrpc": "2.0", + "id": "tools-list-tool1", + "result": map[string]any{ + "tools": []any{ + map[string]any{ + "name": "no_params", + "inputSchema": tool1InputSchema, + }, + }, + }, + }, + }, + { + name: "tools/list on invalid tool set", + url: "/foo", + isErr: true, + body: mcp.JSONRPCRequest{ + Jsonrpc: jsonrpcVersion, + Id: "tools-list-invalid-toolset", + Request: mcp.Request{ + Method: "tools/list", + }, + }, + want: map[string]any{ + "jsonrpc": "2.0", + "id": "tools-list-invalid-toolset", + "error": map[string]any{ + "code": -32600.0, + "message": "toolset does not exist", + }, + }, + }, { name: "missing method", + url: "/", isErr: true, body: mcp.JSONRPCRequest{ Jsonrpc: jsonrpcVersion, @@ -148,6 +197,7 @@ func TestMcpEndpoint(t *testing.T) { }, { name: "invalid method", + url: "/", isErr: true, body: mcp.JSONRPCRequest{ Jsonrpc: jsonrpcVersion, @@ -167,6 +217,7 @@ func TestMcpEndpoint(t *testing.T) { }, { name: "invalid jsonrpc version", + url: "/", isErr: true, body: mcp.JSONRPCRequest{ Jsonrpc: "1.0", @@ -192,7 +243,7 @@ func TestMcpEndpoint(t *testing.T) { t.Fatalf("unexpected error during marshaling of body") } - resp, body, err := runRequest(ts, http.MethodPost, "/", bytes.NewBuffer(reqMarshal)) + resp, body, err := runRequest(ts, http.MethodPost, tc.url, bytes.NewBuffer(reqMarshal)) if err != nil { t.Fatalf("unexpected error during request: %s", err) } @@ -223,36 +274,54 @@ func TestSseEndpoint(t *testing.T) { cacheControl := "no-cache" connection := "keep-alive" accessControlAllowOrigin := "*" - wantEvent := "event: endpoint" - t.Run("test sse endpoint", func(t *testing.T) { - resp, err := http.Get(ts.URL + "/sse") - if err != nil { - t.Fatalf("unexpected error during request: %s", err) - } - defer resp.Body.Close() + testCases := []struct { + name string + url string + event string + }{ + { + name: "basic", + url: "/sse", + event: fmt.Sprintf("event: endpoint\ndata: %s/mcp?sessionId=", ts.URL), + }, + { + name: "toolset1", + url: "/tool1_only/sse", + event: fmt.Sprintf("event: endpoint\ndata: %s/mcp/tool1_only?sessionId=", ts.URL), + }, + } - if gotContentType := resp.Header.Get("Content-type"); gotContentType != contentType { - t.Fatalf("unexpected content-type header: want %s, got %s", contentType, gotContentType) - } - if gotCacheControl := resp.Header.Get("Cache-Control"); gotCacheControl != cacheControl { - t.Fatalf("unexpected cache-control header: want %s, got %s", cacheControl, gotCacheControl) - } - if gotConnection := resp.Header.Get("Connection"); gotConnection != connection { - t.Fatalf("unexpected content-type header: want %s, got %s", connection, gotConnection) - } - if gotAccessControlAllowOrigin := resp.Header.Get("Access-Control-Allow-Origin"); gotAccessControlAllowOrigin != accessControlAllowOrigin { - t.Fatalf("unexpected cache-control header: want %s, got %s", accessControlAllowOrigin, gotAccessControlAllowOrigin) - } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resp, err := http.Get(ts.URL + tc.url) + if err != nil { + t.Fatalf("unexpected error during request: %s", err) + } + defer resp.Body.Close() - buffer := make([]byte, 1024) - n, err := resp.Body.Read(buffer) - if err != nil { - t.Fatalf("unable to read response: %s", err) - } - endpointEvent := string(buffer[:n]) - if !strings.Contains(endpointEvent, wantEvent) { - t.Fatalf("unexpected event: got %s", endpointEvent) - } - }) + if gotContentType := resp.Header.Get("Content-type"); gotContentType != contentType { + t.Fatalf("unexpected content-type header: want %s, got %s", contentType, gotContentType) + } + if gotCacheControl := resp.Header.Get("Cache-Control"); gotCacheControl != cacheControl { + t.Fatalf("unexpected cache-control header: want %s, got %s", cacheControl, gotCacheControl) + } + if gotConnection := resp.Header.Get("Connection"); gotConnection != connection { + t.Fatalf("unexpected content-type header: want %s, got %s", connection, gotConnection) + } + if gotAccessControlAllowOrigin := resp.Header.Get("Access-Control-Allow-Origin"); gotAccessControlAllowOrigin != accessControlAllowOrigin { + t.Fatalf("unexpected cache-control header: want %s, got %s", accessControlAllowOrigin, gotAccessControlAllowOrigin) + } + + buffer := make([]byte, 1024) + n, err := resp.Body.Read(buffer) + if err != nil { + t.Fatalf("unable to read response: %s", err) + } + endpointEvent := string(buffer[:n]) + if !strings.Contains(endpointEvent, tc.event) { + t.Fatalf("unexpected event: got %s, want to contain %s", endpointEvent, tc.event) + } + }) + } }