mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
feat: add toolset feature to mcp (#425)
Update MCP server to support toolset.
User can now connect to specific toolset's sse via
`/mcp/{toolset_name}/sse` url, or POST to `/mcp/{toolset_name}`. If
toolset_name is not provided, it will list all tools by default.
Fixes #403
This commit is contained in:
@@ -76,6 +76,11 @@ func mcpRouter(s *Server) (chi.Router, error) {
|
||||
r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) })
|
||||
r.Post("/", func(w http.ResponseWriter, r *http.Request) { mcpHandler(s, w, r) })
|
||||
|
||||
r.Route("/{toolsetName}", func(r chi.Router) {
|
||||
r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) })
|
||||
r.Post("/", func(w http.ResponseWriter, r *http.Request) { mcpHandler(s, w, r) })
|
||||
})
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
@@ -85,7 +90,10 @@ func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
sessionId := uuid.New().String()
|
||||
toolsetName := chi.URLParam(r, "toolsetName")
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName))
|
||||
span.SetAttributes(attribute.String("session_id", sessionId))
|
||||
span.SetAttributes(attribute.String("toolset_name", toolsetName))
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
@@ -105,6 +113,7 @@ func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
s.instrumentation.McpSse.Add(
|
||||
r.Context(),
|
||||
1,
|
||||
metric.WithAttributes(attribute.String("toolbox.toolset.name", toolsetName)),
|
||||
metric.WithAttributes(attribute.String("toolbox.sse.sessionId", sessionId)),
|
||||
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
|
||||
)
|
||||
@@ -128,15 +137,20 @@ func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// https scheme formatting if (forwarded) request is a TLS request
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
if (proto == "") {
|
||||
if r.TLS == nil {
|
||||
if proto == "" {
|
||||
if r.TLS == nil {
|
||||
proto = "http"
|
||||
} else {
|
||||
proto = "https"
|
||||
}
|
||||
}
|
||||
|
||||
// send initial endpoint event
|
||||
messageEndpoint := fmt.Sprintf("%s://%s/mcp?sessionId=%s", proto, r.Host, sessionId)
|
||||
toolsetURL := ""
|
||||
if toolsetName != "" {
|
||||
toolsetURL = fmt.Sprintf("/%s", toolsetName)
|
||||
}
|
||||
messageEndpoint := fmt.Sprintf("%s://%s/mcp%s?sessionId=%s", proto, r.Host, toolsetURL, sessionId)
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("sending endpoint event: %s", messageEndpoint))
|
||||
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", messageEndpoint)
|
||||
flusher.Flush()
|
||||
@@ -163,6 +177,10 @@ func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp")
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
toolsetName := chi.URLParam(r, "toolsetName")
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName))
|
||||
span.SetAttributes(attribute.String("toolset_name", toolsetName))
|
||||
|
||||
var id, toolName, method string
|
||||
var err error
|
||||
defer func() {
|
||||
@@ -179,7 +197,8 @@ func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
r.Context(),
|
||||
1,
|
||||
metric.WithAttributes(attribute.String("toolbox.sse.sessionId", id)),
|
||||
metric.WithAttributes(attribute.String("toolbox.name", toolName)),
|
||||
metric.WithAttributes(attribute.String("toolbox.tool.name", toolName)),
|
||||
metric.WithAttributes(attribute.String("toolbox.toolset.name", toolsetName)),
|
||||
metric.WithAttributes(attribute.String("toolbox.method", method)),
|
||||
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
|
||||
)
|
||||
@@ -266,7 +285,7 @@ func mcpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
res = newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil)
|
||||
break
|
||||
}
|
||||
toolset, ok := s.toolsets[""]
|
||||
toolset, ok := s.toolsets[toolsetName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("toolset does not exist")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
|
||||
@@ -17,6 +17,7 @@ package server
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
@@ -64,12 +65,14 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
url string
|
||||
isErr bool
|
||||
body mcp.JSONRPCRequest
|
||||
want map[string]any
|
||||
}{
|
||||
{
|
||||
name: "initialize",
|
||||
url: "/",
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "mcp-initialize",
|
||||
@@ -91,6 +94,7 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "basic notification",
|
||||
url: "/",
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Request: mcp.Request{
|
||||
@@ -100,6 +104,7 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "tools/list",
|
||||
url: "/",
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "tools-list",
|
||||
@@ -129,8 +134,52 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tools/list on tool1_only",
|
||||
url: "/tool1_only",
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "tools-list-tool1",
|
||||
Request: mcp.Request{
|
||||
Method: "tools/list",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "tools-list-tool1",
|
||||
"result": map[string]any{
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"name": "no_params",
|
||||
"inputSchema": tool1InputSchema,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tools/list on invalid tool set",
|
||||
url: "/foo",
|
||||
isErr: true,
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "tools-list-invalid-toolset",
|
||||
Request: mcp.Request{
|
||||
Method: "tools/list",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "tools-list-invalid-toolset",
|
||||
"error": map[string]any{
|
||||
"code": -32600.0,
|
||||
"message": "toolset does not exist",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing method",
|
||||
url: "/",
|
||||
isErr: true,
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
@@ -148,6 +197,7 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "invalid method",
|
||||
url: "/",
|
||||
isErr: true,
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
@@ -167,6 +217,7 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "invalid jsonrpc version",
|
||||
url: "/",
|
||||
isErr: true,
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: "1.0",
|
||||
@@ -192,7 +243,7 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
t.Fatalf("unexpected error during marshaling of body")
|
||||
}
|
||||
|
||||
resp, body, err := runRequest(ts, http.MethodPost, "/", bytes.NewBuffer(reqMarshal))
|
||||
resp, body, err := runRequest(ts, http.MethodPost, tc.url, bytes.NewBuffer(reqMarshal))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
@@ -223,36 +274,54 @@ func TestSseEndpoint(t *testing.T) {
|
||||
cacheControl := "no-cache"
|
||||
connection := "keep-alive"
|
||||
accessControlAllowOrigin := "*"
|
||||
wantEvent := "event: endpoint"
|
||||
|
||||
t.Run("test sse endpoint", func(t *testing.T) {
|
||||
resp, err := http.Get(ts.URL + "/sse")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
testCases := []struct {
|
||||
name string
|
||||
url string
|
||||
event string
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
url: "/sse",
|
||||
event: fmt.Sprintf("event: endpoint\ndata: %s/mcp?sessionId=", ts.URL),
|
||||
},
|
||||
{
|
||||
name: "toolset1",
|
||||
url: "/tool1_only/sse",
|
||||
event: fmt.Sprintf("event: endpoint\ndata: %s/mcp/tool1_only?sessionId=", ts.URL),
|
||||
},
|
||||
}
|
||||
|
||||
if gotContentType := resp.Header.Get("Content-type"); gotContentType != contentType {
|
||||
t.Fatalf("unexpected content-type header: want %s, got %s", contentType, gotContentType)
|
||||
}
|
||||
if gotCacheControl := resp.Header.Get("Cache-Control"); gotCacheControl != cacheControl {
|
||||
t.Fatalf("unexpected cache-control header: want %s, got %s", cacheControl, gotCacheControl)
|
||||
}
|
||||
if gotConnection := resp.Header.Get("Connection"); gotConnection != connection {
|
||||
t.Fatalf("unexpected content-type header: want %s, got %s", connection, gotConnection)
|
||||
}
|
||||
if gotAccessControlAllowOrigin := resp.Header.Get("Access-Control-Allow-Origin"); gotAccessControlAllowOrigin != accessControlAllowOrigin {
|
||||
t.Fatalf("unexpected cache-control header: want %s, got %s", accessControlAllowOrigin, gotAccessControlAllowOrigin)
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, err := http.Get(ts.URL + tc.url)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
buffer := make([]byte, 1024)
|
||||
n, err := resp.Body.Read(buffer)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to read response: %s", err)
|
||||
}
|
||||
endpointEvent := string(buffer[:n])
|
||||
if !strings.Contains(endpointEvent, wantEvent) {
|
||||
t.Fatalf("unexpected event: got %s", endpointEvent)
|
||||
}
|
||||
})
|
||||
if gotContentType := resp.Header.Get("Content-type"); gotContentType != contentType {
|
||||
t.Fatalf("unexpected content-type header: want %s, got %s", contentType, gotContentType)
|
||||
}
|
||||
if gotCacheControl := resp.Header.Get("Cache-Control"); gotCacheControl != cacheControl {
|
||||
t.Fatalf("unexpected cache-control header: want %s, got %s", cacheControl, gotCacheControl)
|
||||
}
|
||||
if gotConnection := resp.Header.Get("Connection"); gotConnection != connection {
|
||||
t.Fatalf("unexpected content-type header: want %s, got %s", connection, gotConnection)
|
||||
}
|
||||
if gotAccessControlAllowOrigin := resp.Header.Get("Access-Control-Allow-Origin"); gotAccessControlAllowOrigin != accessControlAllowOrigin {
|
||||
t.Fatalf("unexpected cache-control header: want %s, got %s", accessControlAllowOrigin, gotAccessControlAllowOrigin)
|
||||
}
|
||||
|
||||
buffer := make([]byte, 1024)
|
||||
n, err := resp.Body.Read(buffer)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to read response: %s", err)
|
||||
}
|
||||
endpointEvent := string(buffer[:n])
|
||||
if !strings.Contains(endpointEvent, tc.event) {
|
||||
t.Fatalf("unexpected event: got %s, want to contain %s", endpointEvent, tc.event)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user