mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
feat: support MCP version 2025-03-26 (#755)
This feature includes the following: * Implement initialize lifecycle (including version negotiation) * Add the v20250326 schema * Supporting the `DELETE` and `GET` endpoint for MCP. * Supporting streamable HTTP (without SSE). * Terminating sessions after timeout (default = 10 minutes from last active). * Toolbox do not support batch request. Will response with `Invalid requests` if batch requests is received.
This commit is contained in:
@@ -17,13 +17,13 @@ package server
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
@@ -199,7 +199,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
s.logger.DebugContext(ctx, "tool invocation authorized")
|
||||
|
||||
var data map[string]any
|
||||
if err = decodeJSON(r.Body, &data); err != nil {
|
||||
if err = util.DecodeJSON(r.Body, &data); err != nil {
|
||||
render.Status(r, http.StatusBadRequest)
|
||||
err = fmt.Errorf("request body was invalid JSON: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
@@ -274,13 +274,3 @@ func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
||||
render.Status(r, e.HTTPStatusCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeJSON decodes a given reader into an interface using the json decoder.
|
||||
func decodeJSON(r io.Reader, v interface{}) error {
|
||||
defer io.Copy(io.Discard, r) //nolint:errcheck
|
||||
d := json.NewDecoder(r)
|
||||
// specify JSON numbers should get parsed to json.Number instead of float64 by default.
|
||||
// This prevents loss between floats/ints.
|
||||
d.UseNumber()
|
||||
return d.Decode(v)
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ func TestToolsetEndpoint(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, body, err := runRequest(ts, http.MethodGet, fmt.Sprintf("/toolset/%s", tc.toolsetName), nil)
|
||||
resp, body, err := runRequest(ts, http.MethodGet, fmt.Sprintf("/toolset/%s", tc.toolsetName), nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
@@ -174,7 +174,7 @@ func TestToolGetEndpoint(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, body, err := runRequest(ts, http.MethodGet, fmt.Sprintf("/tool/%s", tc.toolName), nil)
|
||||
resp, body, err := runRequest(ts, http.MethodGet, fmt.Sprintf("/tool/%s", tc.toolName), nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
@@ -251,7 +251,7 @@ func TestToolInvokeEndpoint(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, body, err := runRequest(ts, http.MethodPost, fmt.Sprintf("/tool/%s/invoke", tc.toolName), tc.requestBody)
|
||||
resp, body, err := runRequest(ts, http.MethodPost, fmt.Sprintf("/tool/%s/invoke", tc.toolName), tc.requestBody, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -153,10 +152,7 @@ func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, tools
|
||||
t.Fatalf("unable to create custom metrics: %s", err)
|
||||
}
|
||||
|
||||
sseManager := &sseManager{
|
||||
mu: sync.RWMutex{},
|
||||
sseSessions: make(map[string]*sseSession),
|
||||
}
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
server := Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, sseManager: sseManager, tools: tools, toolsets: toolsets}
|
||||
var r chi.Router
|
||||
@@ -197,12 +193,17 @@ func runServer(r chi.Router, tls bool) *httptest.Server {
|
||||
return ts
|
||||
}
|
||||
|
||||
func runRequest(ts *httptest.Server, method, path string, body io.Reader) (*http.Response, []byte, error) {
|
||||
func runRequest(ts *httptest.Server, method, path string, body io.Reader, header map[string]string) (*http.Response, []byte, error) {
|
||||
req, err := http.NewRequest(method, ts.URL+path, body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
for k, v := range header {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to send request: %w", err)
|
||||
|
||||
@@ -23,12 +23,17 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
mcputil "github.com/googleapis/genai-toolbox/internal/server/mcp/util"
|
||||
v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105"
|
||||
v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
@@ -36,30 +41,41 @@ import (
|
||||
)
|
||||
|
||||
type sseSession struct {
|
||||
sessionId string
|
||||
writer http.ResponseWriter
|
||||
flusher http.Flusher
|
||||
done chan struct{}
|
||||
eventQueue chan string
|
||||
lastActive time.Time
|
||||
}
|
||||
|
||||
// sseManager manages and control access to sse sessions
|
||||
type sseManager struct {
|
||||
mu sync.RWMutex
|
||||
mu sync.Mutex
|
||||
sseSessions map[string]*sseSession
|
||||
}
|
||||
|
||||
func (m *sseManager) get(id string) (*sseSession, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
session, ok := m.sseSessions[id]
|
||||
session.lastActive = time.Now()
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func newSseManager(ctx context.Context) *sseManager {
|
||||
sseM := &sseManager{
|
||||
mu: sync.Mutex{},
|
||||
sseSessions: make(map[string]*sseSession),
|
||||
}
|
||||
go sseM.cleanupRoutine(ctx)
|
||||
return sseM
|
||||
}
|
||||
|
||||
func (m *sseManager) add(id string, session *sseSession) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.sseSessions[id] = session
|
||||
m.mu.Unlock()
|
||||
session.lastActive = time.Now()
|
||||
}
|
||||
|
||||
func (m *sseManager) remove(id string) {
|
||||
@@ -68,10 +84,35 @@ func (m *sseManager) remove(id string) {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *sseManager) cleanupRoutine(ctx context.Context) {
|
||||
timeout := 10 * time.Minute
|
||||
ticker := time.NewTicker(timeout)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
func() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
now := time.Now()
|
||||
for id, sess := range m.sseSessions {
|
||||
if now.Sub(sess.lastActive) > timeout {
|
||||
delete(m.sseSessions, id)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type stdioSession struct {
|
||||
server *Server
|
||||
reader *bufio.Reader
|
||||
writer io.Writer
|
||||
protocol string
|
||||
server *Server
|
||||
reader *bufio.Reader
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
func NewStdioSession(s *Server, stdin io.Reader, stdout io.Writer) *stdioSession {
|
||||
@@ -100,13 +141,15 @@ func (s *stdioSession) readInputStream(ctx context.Context) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
res, err := processMcpMessage(ctx, []byte(line), s.server, "")
|
||||
v, res, err := processMcpMessage(ctx, []byte(line), s.server, s.protocol, "")
|
||||
if err != nil {
|
||||
// errors during the processing of message will generate a valid MCP Error response.
|
||||
// server can continue to run.
|
||||
s.server.logger.ErrorContext(ctx, err.Error())
|
||||
}
|
||||
|
||||
if v != "" {
|
||||
s.protocol = v
|
||||
}
|
||||
// no responses for notifications
|
||||
if res != nil {
|
||||
if err = s.write(ctx, res); err != nil {
|
||||
@@ -176,11 +219,15 @@ func mcpRouter(s *Server) (chi.Router, error) {
|
||||
r.Use(render.SetContentType(render.ContentTypeJSON))
|
||||
|
||||
r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) })
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) { methodNotAllowed(s, w, r) })
|
||||
r.Post("/", func(w http.ResponseWriter, r *http.Request) { httpHandler(s, w, r) })
|
||||
r.Delete("/", func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
r.Route("/{toolsetName}", func(r chi.Router) {
|
||||
r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) })
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) { methodNotAllowed(s, w, r) })
|
||||
r.Post("/", func(w http.ResponseWriter, r *http.Request) { httpHandler(s, w, r) })
|
||||
r.Delete("/", func(w http.ResponseWriter, r *http.Request) {})
|
||||
})
|
||||
|
||||
return r, nil
|
||||
@@ -228,7 +275,6 @@ func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
||||
}
|
||||
session := &sseSession{
|
||||
sessionId: sessionId,
|
||||
writer: w,
|
||||
flusher: flusher,
|
||||
done: make(chan struct{}),
|
||||
@@ -274,19 +320,47 @@ func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// methodNotAllowed handles all mcp messages.
|
||||
func methodNotAllowed(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
err := fmt.Errorf("toolbox does not support streaming in streamable HTTP transport")
|
||||
s.logger.DebugContext(r.Context(), err.Error())
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusMethodNotAllowed))
|
||||
}
|
||||
|
||||
// httpHandler handles all mcp messages.
|
||||
func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp")
|
||||
r = r.WithContext(ctx)
|
||||
ctx = util.WithLogger(r.Context(), s.logger)
|
||||
|
||||
var sessionId, protocolVersion string
|
||||
var session *sseSession
|
||||
|
||||
// check if client connects via sse
|
||||
// v2024-11-05 supports http with sse
|
||||
paramSessionId := r.URL.Query().Get("sessionId")
|
||||
if paramSessionId != "" {
|
||||
sessionId = paramSessionId
|
||||
protocolVersion = v20241105.PROTOCOL_VERSION
|
||||
var ok bool
|
||||
session, ok = s.sseManager.get(sessionId)
|
||||
if !ok {
|
||||
s.logger.DebugContext(ctx, "sse session not available")
|
||||
}
|
||||
}
|
||||
|
||||
// check if client have `Mcp-Session-Id` header
|
||||
// if `Mcp-Session-Id` header is set, we are using v2025-03-26 since
|
||||
// previous version doesn't use this header.
|
||||
headerSessionId := r.Header.Get("Mcp-Session-Id")
|
||||
if headerSessionId != "" {
|
||||
protocolVersion = v20250326.PROTOCOL_VERSION
|
||||
}
|
||||
|
||||
toolsetName := chi.URLParam(r, "toolsetName")
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName))
|
||||
span.SetAttributes(attribute.String("toolset_name", toolsetName))
|
||||
|
||||
// retrieve sse session id, if applicable
|
||||
sessionId := r.URL.Query().Get("sessionId")
|
||||
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
@@ -312,10 +386,10 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
// Generate a new uuid if unable to decode
|
||||
id := uuid.New().String()
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
render.JSON(w, r, newJSONRPCError(id, mcp.PARSE_ERROR, err.Error(), nil))
|
||||
render.JSON(w, r, jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil))
|
||||
}
|
||||
|
||||
res, err := processMcpMessage(ctx, body, s, toolsetName)
|
||||
v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName)
|
||||
// notifications will return empty string
|
||||
if res == nil {
|
||||
// Notifications do not expect a response
|
||||
@@ -327,11 +401,13 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
}
|
||||
|
||||
// retrieve sse session
|
||||
session, ok := s.sseManager.get(sessionId)
|
||||
if !ok {
|
||||
s.logger.DebugContext(ctx, "sse session not available")
|
||||
} else {
|
||||
// for v20250326, add the `Mcp-Session-Id` header
|
||||
if v == v20250326.PROTOCOL_VERSION {
|
||||
sessionId = uuid.New().String()
|
||||
w.Header().Set("Mcp-Session-Id", sessionId)
|
||||
}
|
||||
|
||||
if session != nil {
|
||||
// queue sse event
|
||||
eventData, _ := json.Marshal(res)
|
||||
select {
|
||||
@@ -349,141 +425,66 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// processMcpMessage process the messages received from clients
|
||||
func processMcpMessage(ctx context.Context, body []byte, s *Server, toolsetName string) (any, error) {
|
||||
func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string) (string, any, error) {
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return newJSONRPCError("", mcp.INTERNAL_ERROR, err.Error(), nil), err
|
||||
return "", jsonrpc.NewError("", jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
if protocolVersion == "" {
|
||||
protocolVersion = v20241105.PROTOCOL_VERSION
|
||||
}
|
||||
|
||||
// Generic baseMessage could either be a JSONRPCNotification or JSONRPCRequest
|
||||
var baseMessage struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
Id mcp.RequestId `json:"id,omitempty"`
|
||||
}
|
||||
if err = decodeJSON(bytes.NewBuffer(body), &baseMessage); err != nil {
|
||||
var baseMessage jsonrpc.BaseMessage
|
||||
if err = util.DecodeJSON(bytes.NewBuffer(body), &baseMessage); err != nil {
|
||||
// Generate a new uuid if unable to decode
|
||||
id := uuid.New().String()
|
||||
return newJSONRPCError(id, mcp.PARSE_ERROR, err.Error(), nil), err
|
||||
|
||||
// check if user is sending a batch request
|
||||
var a []any
|
||||
unmarshalErr := json.Unmarshal(body, &a)
|
||||
if unmarshalErr == nil {
|
||||
err = fmt.Errorf("not supporting batch requests")
|
||||
return "", jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
return "", jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// Check if method is present
|
||||
if baseMessage.Method == "" {
|
||||
err = fmt.Errorf("method not found")
|
||||
return newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil), err
|
||||
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("method is: %s", baseMessage.Method))
|
||||
|
||||
// Check for JSON-RPC 2.0
|
||||
if baseMessage.Jsonrpc != mcp.JSONRPC_VERSION {
|
||||
if baseMessage.Jsonrpc != jsonrpc.JSONRPC_VERSION {
|
||||
err = fmt.Errorf("invalid json-rpc version")
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
|
||||
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// Check if message is a notification
|
||||
if baseMessage.Id == nil {
|
||||
var notification mcp.JSONRPCNotification
|
||||
if err = json.Unmarshal(body, ¬ification); err != nil {
|
||||
err = fmt.Errorf("invalid notification request: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
return nil, nil
|
||||
err := mcp.NotificationHandler(ctx, body)
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
switch baseMessage.Method {
|
||||
case "initialize":
|
||||
var req mcp.InitializeRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp initialize request: %w", err)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
result := mcp.Initialize(s.version)
|
||||
return mcp.JSONRPCResponse{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: baseMessage.Id,
|
||||
Result: result,
|
||||
}, nil
|
||||
case "tools/list":
|
||||
var req mcp.ListToolsRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools list request: %w", err)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
|
||||
case mcputil.INITIALIZE:
|
||||
res, v, err := mcp.InitializeResponse(ctx, baseMessage.Id, body, s.version)
|
||||
if err != nil {
|
||||
return "", res, err
|
||||
}
|
||||
return v, res, err
|
||||
default:
|
||||
toolset, ok := s.toolsets[toolsetName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("toolset does not exist")
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
|
||||
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
result := mcp.ToolsList(toolset)
|
||||
return mcp.JSONRPCResponse{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: baseMessage.Id,
|
||||
Result: result,
|
||||
}, nil
|
||||
case "tools/call":
|
||||
var req mcp.CallToolRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools call request: %w", err)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
toolName := req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||
tool, ok := s.tools[toolName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
|
||||
aMarshal, err := json.Marshal(toolArgument)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to marshal tools argument: %w", err)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
var data map[string]any
|
||||
if err = decodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
|
||||
err = fmt.Errorf("unable to decode tools argument: %w", err)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
|
||||
// Since MCP doesn't support auth, an empty map will be use every time.
|
||||
claimsFromAuth := make(map[string]map[string]any)
|
||||
|
||||
params, err := tool.ParseParams(data, claimsFromAuth)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
if !tool.Authorized([]string{}) {
|
||||
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool")
|
||||
return newJSONRPCError(baseMessage.Id, mcp.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
result := mcp.ToolCall(ctx, tool, params)
|
||||
return mcp.JSONRPCResponse{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: baseMessage.Id,
|
||||
Result: result,
|
||||
}, nil
|
||||
default:
|
||||
err = fmt.Errorf("invalid method %s", baseMessage.Method)
|
||||
return newJSONRPCError(baseMessage.Id, mcp.METHOD_NOT_FOUND, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
|
||||
// newJSONRPCError is the response sent back when an error has been encountered in mcp.
|
||||
func newJSONRPCError(id mcp.RequestId, code int, message string, data any) mcp.JSONRPCError {
|
||||
return mcp.JSONRPCError{
|
||||
Jsonrpc: mcp.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Error: mcp.McpError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Data: data,
|
||||
},
|
||||
res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, s.tools, body)
|
||||
return "", res, err
|
||||
}
|
||||
}
|
||||
|
||||
125
internal/server/mcp/jsonrpc/jsonrpc.go
Normal file
125
internal/server/mcp/jsonrpc/jsonrpc.go
Normal file
@@ -0,0 +1,125 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package jsonrpc
|
||||
|
||||
// JSONRPC_VERSION is the version of JSON-RPC used by MCP.
|
||||
const JSONRPC_VERSION = "2.0"
|
||||
|
||||
// Standard JSON-RPC error codes
|
||||
const (
|
||||
PARSE_ERROR = -32700
|
||||
INVALID_REQUEST = -32600
|
||||
METHOD_NOT_FOUND = -32601
|
||||
INVALID_PARAMS = -32602
|
||||
INTERNAL_ERROR = -32603
|
||||
)
|
||||
|
||||
// ProgressToken is used to associate progress notifications with the original request.
|
||||
type ProgressToken interface{}
|
||||
|
||||
// RequestId is a uniquely identifying ID for a request in JSON-RPC.
|
||||
// It can be any JSON-serializable value, typically a number or string.
|
||||
type RequestId interface{}
|
||||
|
||||
// Request represents a bidirectional message with method and parameters expecting a response.
|
||||
type Request struct {
|
||||
Method string `json:"method"`
|
||||
Params struct {
|
||||
Meta struct {
|
||||
// If specified, the caller is requesting out-of-band progress
|
||||
// notifications for this request (as represented by
|
||||
// notifications/progress). The value of this parameter is an
|
||||
// opaque token that will be attached to any subsequent
|
||||
// notifications. The receiver is not obligated to provide these
|
||||
// notifications.
|
||||
ProgressToken ProgressToken `json:"progressToken,omitempty"`
|
||||
} `json:"_meta,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// JSONRPCRequest represents a request that expects a response.
|
||||
type JSONRPCRequest struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Id RequestId `json:"id"`
|
||||
Request
|
||||
Params any `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// Notification is a one-way message requiring no response.
|
||||
type Notification struct {
|
||||
Method string `json:"method"`
|
||||
Params struct {
|
||||
Meta map[string]interface{} `json:"_meta,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// JSONRPCNotification represents a notification which does not expect a response.
|
||||
type JSONRPCNotification struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Notification
|
||||
}
|
||||
|
||||
// Result represents a response for the request query.
|
||||
type Result struct {
|
||||
// This result property is reserved by the protocol to allow clients and
|
||||
// servers to attach additional metadata to their responses.
|
||||
Meta map[string]interface{} `json:"_meta,omitempty"`
|
||||
}
|
||||
|
||||
// JSONRPCResponse represents a successful (non-error) response to a request.
|
||||
type JSONRPCResponse struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Id RequestId `json:"id"`
|
||||
Result interface{} `json:"result"`
|
||||
}
|
||||
|
||||
// Error represents the error content.
|
||||
type Error struct {
|
||||
// The error type that occurred.
|
||||
Code int `json:"code"`
|
||||
// A short description of the error. The message SHOULD be limited
|
||||
// to a concise single sentence.
|
||||
Message string `json:"message"`
|
||||
// Additional information about the error. The value of this member
|
||||
// is defined by the sender (e.g. detailed error information, nested errors etc.).
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// JSONRPCError represents a non-successful (error) response to a request.
|
||||
type JSONRPCError struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Id RequestId `json:"id"`
|
||||
Error Error `json:"error"`
|
||||
}
|
||||
|
||||
// Generic baseMessage could either be a JSONRPCNotification or JSONRPCRequest
|
||||
type BaseMessage struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
Id RequestId `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
// NewError is the standard JSONRPC response sent back when an error has been encountered.
|
||||
func NewError(id RequestId, code int, message string, data any) JSONRPCError {
|
||||
return JSONRPCError{
|
||||
Jsonrpc: JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Error: Error{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Data: data,
|
||||
},
|
||||
}
|
||||
}
|
||||
99
internal/server/mcp/mcp.go
Normal file
99
internal/server/mcp/mcp.go
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
mcputil "github.com/googleapis/genai-toolbox/internal/server/mcp/util"
|
||||
v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105"
|
||||
v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
// LATEST_PROTOCOL_VERSION is the latest version of the MCP protocol supported.
|
||||
// Update the version used in InitializeResponse when this value is updated.
|
||||
const LATEST_PROTOCOL_VERSION = v20250326.PROTOCOL_VERSION
|
||||
|
||||
// SUPPORTED_PROTOCOL_VERSIONS is the MCP protocol versions that are supported.
|
||||
var SUPPORTED_PROTOCOL_VERSIONS = []string{v20241105.PROTOCOL_VERSION, v20250326.PROTOCOL_VERSION}
|
||||
|
||||
// InitializeResponse runs capability negotiation and protocol version agreement.
|
||||
// This is the Initialization phase of the lifecycle for MCP client-server connections.
|
||||
// Always start with the latest protocol version supported.
|
||||
func InitializeResponse(ctx context.Context, id jsonrpc.RequestId, body []byte, toolboxVersion string) (any, string, error) {
|
||||
var req mcputil.InitializeRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp initialize request: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), "", err
|
||||
}
|
||||
|
||||
var protocolVersion string
|
||||
v := req.Params.ProtocolVersion
|
||||
if slices.Contains(SUPPORTED_PROTOCOL_VERSIONS, v) {
|
||||
protocolVersion = v
|
||||
} else {
|
||||
protocolVersion = LATEST_PROTOCOL_VERSION
|
||||
}
|
||||
|
||||
toolsListChanged := false
|
||||
result := mcputil.InitializeResult{
|
||||
ProtocolVersion: protocolVersion,
|
||||
Capabilities: mcputil.ServerCapabilities{
|
||||
Tools: &mcputil.ListChanged{
|
||||
ListChanged: &toolsListChanged,
|
||||
},
|
||||
},
|
||||
ServerInfo: mcputil.Implementation{
|
||||
Name: mcputil.SERVER_NAME,
|
||||
Version: toolboxVersion,
|
||||
},
|
||||
}
|
||||
res := jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: result,
|
||||
}
|
||||
|
||||
return res, protocolVersion, nil
|
||||
}
|
||||
|
||||
// NotificationHandler process notifications request. It MUST NOT send a response.
|
||||
// Currently Toolbox does not process any notifications.
|
||||
func NotificationHandler(ctx context.Context, body []byte) error {
|
||||
var notification jsonrpc.JSONRPCNotification
|
||||
if err := json.Unmarshal(body, ¬ification); err != nil {
|
||||
return fmt.Errorf("invalid notification request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProcessMethod returns a response for the request.
|
||||
// This is the Operation phase of the lifecycle for MCP client-server connections.
|
||||
func ProcessMethod(ctx context.Context, mcpVersion string, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, body []byte) (any, error) {
|
||||
switch mcpVersion {
|
||||
case v20250326.PROTOCOL_VERSION:
|
||||
return v20250326.ProcessMethod(ctx, id, method, toolset, tools, body)
|
||||
case v20241105.PROTOCOL_VERSION:
|
||||
return v20241105.ProcessMethod(ctx, id, method, toolset, tools, body)
|
||||
default:
|
||||
err := fmt.Errorf("invalid protocol version: %s", mcpVersion)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
func Initialize(version string) InitializeResult {
|
||||
toolsListChanged := false
|
||||
result := InitializeResult{
|
||||
ProtocolVersion: LATEST_PROTOCOL_VERSION,
|
||||
Capabilities: ServerCapabilities{
|
||||
Tools: &ListChanged{
|
||||
ListChanged: &toolsListChanged,
|
||||
},
|
||||
},
|
||||
ServerInfo: Implementation{
|
||||
Name: SERVER_NAME,
|
||||
Version: version,
|
||||
},
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ToolsList return a ListToolsResult
|
||||
func ToolsList(toolset tools.Toolset) ListToolsResult {
|
||||
mcpManifest := toolset.McpManifest
|
||||
|
||||
result := ListToolsResult{
|
||||
Tools: mcpManifest,
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ToolCall runs tool invocation and return a CallToolResult
|
||||
func ToolCall(ctx context.Context, tool tools.Tool, params tools.ParamValues) CallToolResult {
|
||||
res, err := tool.Invoke(ctx, params)
|
||||
if err != nil {
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
}
|
||||
return CallToolResult{Content: []TextContent{text}, IsError: true}
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
for _, d := range res {
|
||||
text := TextContent{Type: "text"}
|
||||
dM, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
text.Text = fmt.Sprintf("fail to marshal: %s, result: %s", err, d)
|
||||
} else {
|
||||
text.Text = string(dM)
|
||||
}
|
||||
content = append(content, text)
|
||||
}
|
||||
return CallToolResult{Content: content}
|
||||
}
|
||||
@@ -1,295 +0,0 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
// SERVER_NAME is the server name used in Implementation.
|
||||
const SERVER_NAME = "Toolbox"
|
||||
|
||||
// LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol.
|
||||
const LATEST_PROTOCOL_VERSION = "2024-11-05"
|
||||
|
||||
// JSONRPC_VERSION is the version of JSON-RPC used by MCP.
|
||||
const JSONRPC_VERSION = "2.0"
|
||||
|
||||
// Standard JSON-RPC error codes
|
||||
const (
|
||||
PARSE_ERROR = -32700
|
||||
INVALID_REQUEST = -32600
|
||||
METHOD_NOT_FOUND = -32601
|
||||
INVALID_PARAMS = -32602
|
||||
INTERNAL_ERROR = -32603
|
||||
)
|
||||
|
||||
// JSONRPCMessage represents either a JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, or JSONRPCError.
|
||||
type JSONRPCMessage interface{}
|
||||
|
||||
// ProgressToken is used to associate progress notifications with the original request.
|
||||
type ProgressToken interface{}
|
||||
|
||||
// Request represents a bidirectional message with method and parameters expecting a response.
|
||||
type Request struct {
|
||||
Method string `json:"method"`
|
||||
Params struct {
|
||||
Meta struct {
|
||||
// If specified, the caller is requesting out-of-band progress
|
||||
// notifications for this request (as represented by
|
||||
// notifications/progress). The value of this parameter is an
|
||||
// opaque token that will be attached to any subsequent
|
||||
// notifications. The receiver is not obligated to provide these
|
||||
// notifications.
|
||||
ProgressToken ProgressToken `json:"progressToken,omitempty"`
|
||||
} `json:"_meta,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// Notification is a one-way message requiring no response.
|
||||
type Notification struct {
|
||||
Method string `json:"method"`
|
||||
Params struct {
|
||||
Meta map[string]interface{} `json:"_meta,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// Result represents a response for the request query.
|
||||
type Result struct {
|
||||
// This result property is reserved by the protocol to allow clients and
|
||||
// servers to attach additional metadata to their responses.
|
||||
Meta map[string]interface{} `json:"_meta,omitempty"`
|
||||
}
|
||||
|
||||
// RequestId is a uniquely identifying ID for a request in JSON-RPC.
|
||||
// It can be any JSON-serializable value, typically a number or string.
|
||||
type RequestId interface{}
|
||||
|
||||
// JSONRPCRequest represents a request that expects a response.
|
||||
type JSONRPCRequest struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Id RequestId `json:"id"`
|
||||
Request
|
||||
Params any `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// JSONRPCNotification represents a notification which does not expect a response.
|
||||
type JSONRPCNotification struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Notification
|
||||
}
|
||||
|
||||
// JSONRPCResponse represents a successful (non-error) response to a request.
|
||||
type JSONRPCResponse struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Id RequestId `json:"id"`
|
||||
Result interface{} `json:"result"`
|
||||
}
|
||||
|
||||
// McpError represents the error content.
|
||||
type McpError struct {
|
||||
// The error type that occurred.
|
||||
Code int `json:"code"`
|
||||
// A short description of the error. The message SHOULD be limited
|
||||
// to a concise single sentence.
|
||||
Message string `json:"message"`
|
||||
// Additional information about the error. The value of this member
|
||||
// is defined by the sender (e.g. detailed error information, nested errors etc.).
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// JSONRPCError represents a non-successful (error) response to a request.
|
||||
type JSONRPCError struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
Id RequestId `json:"id"`
|
||||
Error McpError `json:"error"`
|
||||
}
|
||||
|
||||
/* Empty result */
|
||||
|
||||
// EmptyResult represents a response that indicates success but carries no data.
|
||||
type EmptyResult Result
|
||||
|
||||
/* Initialization */
|
||||
|
||||
// Params to define MCP Client during initialize request.
|
||||
type InitializeParams struct {
|
||||
// The latest version of the Model Context Protocol that the client supports.
|
||||
// The client MAY decide to support older versions as well.
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities ClientCapabilities `json:"capabilities"`
|
||||
ClientInfo Implementation `json:"clientInfo"`
|
||||
}
|
||||
|
||||
// InitializeRequest is sent from the client to the server when it first
|
||||
// connects, asking it to begin initialization.
|
||||
type InitializeRequest struct {
|
||||
Request
|
||||
Params InitializeParams `json:"params"`
|
||||
}
|
||||
|
||||
// InitializeResult is sent after receiving an initialize request from the
|
||||
// client.
|
||||
type InitializeResult struct {
|
||||
Result
|
||||
// The version of the Model Context Protocol that the server wants to use.
|
||||
// This may not match the version that the client requested. If the client cannot
|
||||
// support this version, it MUST disconnect.
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities ServerCapabilities `json:"capabilities"`
|
||||
ServerInfo Implementation `json:"serverInfo"`
|
||||
// Instructions describing how to use the server and its features.
|
||||
//
|
||||
// This can be used by clients to improve the LLM's understanding of
|
||||
// available tools, resources, etc. It can be thought of like a "hint" to the model.
|
||||
// For example, this information MAY be added to the system prompt.
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
}
|
||||
|
||||
// InitializedNotification is sent from the client to the server after
|
||||
// initialization has finished.
|
||||
type InitializedNotification struct {
|
||||
Notification
|
||||
}
|
||||
|
||||
// ListChange represents whether the server supports notification for changes to the capabilities.
|
||||
type ListChanged struct {
|
||||
ListChanged *bool `json:"listChanged,omitempty"`
|
||||
}
|
||||
|
||||
// ClientCapabilities represents capabilities a client may support. Known
|
||||
// capabilities are defined here, in this schema, but this is not a closed set: any
|
||||
// client can define its own, additional capabilities.
|
||||
type ClientCapabilities struct {
|
||||
// Experimental, non-standard capabilities that the client supports.
|
||||
Experimental map[string]interface{} `json:"experimental,omitempty"`
|
||||
// Present if the client supports listing roots.
|
||||
Roots *ListChanged `json:"roots,omitempty"`
|
||||
// Present if the client supports sampling from an LLM.
|
||||
Sampling struct{} `json:"sampling,omitempty"`
|
||||
}
|
||||
|
||||
// ServerCapabilities represents capabilities that a server may support. Known
|
||||
// capabilities are defined here, in this schema, but this is not a closed set: any
|
||||
// server can define its own, additional capabilities.
|
||||
type ServerCapabilities struct {
|
||||
Tools *ListChanged `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// Implementation describes the name and version of an MCP implementation.
|
||||
type Implementation struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
/* Pagination */
|
||||
|
||||
// Cursor is an opaque token used to represent a cursor for pagination.
|
||||
type Cursor string
|
||||
|
||||
type PaginatedRequest struct {
|
||||
Request
|
||||
Params struct {
|
||||
// An opaque token representing the current pagination position.
|
||||
// If provided, the server should return results starting after this cursor.
|
||||
Cursor Cursor `json:"cursor,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type PaginatedResult struct {
|
||||
Result
|
||||
// An opaque token representing the pagination position after the last returned result.
|
||||
// If present, there may be more results available.
|
||||
NextCursor Cursor `json:"nextCursor,omitempty"`
|
||||
}
|
||||
|
||||
/* Tools */
|
||||
|
||||
// Sent from the client to request a list of tools the server has.
|
||||
type ListToolsRequest struct {
|
||||
PaginatedRequest
|
||||
}
|
||||
|
||||
// The server's response to a tools/list request from the client.
|
||||
type ListToolsResult struct {
|
||||
PaginatedResult
|
||||
Tools []tools.McpManifest `json:"tools"`
|
||||
}
|
||||
|
||||
// Used by the client to invoke a tool provided by the server.
|
||||
type CallToolRequest struct {
|
||||
Request
|
||||
Params struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// The sender or recipient of messages and data in a conversation.
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleUser Role = "user"
|
||||
RoleAssistant Role = "assistant"
|
||||
)
|
||||
|
||||
// Base for objects that include optional annotations for the client.
|
||||
// The client can use annotations to inform how objects are used or displayed
|
||||
type Annotated struct {
|
||||
Annotations *struct {
|
||||
// Describes who the intended customer of this object or data is.
|
||||
// It can include multiple entries to indicate content useful for multiple
|
||||
// audiences (e.g., `["user", "assistant"]`).
|
||||
Audience []Role `json:"audience,omitempty"`
|
||||
// Describes how important this data is for operating the server.
|
||||
//
|
||||
// A value of 1 means "most important," and indicates that the data is
|
||||
// effectively required, while 0 means "least important," and indicates that
|
||||
// the data is entirely optional.
|
||||
//
|
||||
// @TJS-type number
|
||||
// @minimum 0
|
||||
// @maximum 1
|
||||
Priority float64 `json:"priority,omitempty"`
|
||||
} `json:"annotations,omitempty"`
|
||||
}
|
||||
|
||||
// TextContent represents text provided to or from an LLM.
|
||||
type TextContent struct {
|
||||
Annotated
|
||||
Type string `json:"type"`
|
||||
// The text content of the message.
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// The server's response to a tool call.
|
||||
//
|
||||
// Any errors that originate from the tool SHOULD be reported inside the result
|
||||
// object, with `isError` set to true, _not_ as an MCP protocol-level error
|
||||
// response. Otherwise, the LLM would not be able to see that an error occurred
|
||||
// and self-correct.
|
||||
//
|
||||
// However, any errors in _finding_ the tool, an error indicating that the
|
||||
// server does not support tool calls, or any other exceptional conditions,
|
||||
// should be reported as an MCP error response.
|
||||
type CallToolResult struct {
|
||||
Result
|
||||
// Could be either a TextContent, ImageContent, or EmbeddedResources
|
||||
// For Toolbox, we will only be sending TextContent
|
||||
Content []TextContent `json:"content"`
|
||||
// Whether the tool call ended in an error.
|
||||
// If not set, this is assumed to be false (the call was successful).
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
96
internal/server/mcp/util/lifecycle.go
Normal file
96
internal/server/mcp/util/lifecycle.go
Normal file
@@ -0,0 +1,96 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package util
|
||||
|
||||
import "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
|
||||
const (
|
||||
// SERVER_NAME is the server name used in Implementation.
|
||||
SERVER_NAME = "Toolbox"
|
||||
// methods that are supported
|
||||
INITIALIZE = "initialize"
|
||||
)
|
||||
|
||||
/* Initialization */
|
||||
|
||||
// Params to define MCP Client during initialize request.
|
||||
type InitializeParams struct {
|
||||
// The latest version of the Model Context Protocol that the client supports.
|
||||
// The client MAY decide to support older versions as well.
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities ClientCapabilities `json:"capabilities"`
|
||||
ClientInfo Implementation `json:"clientInfo"`
|
||||
}
|
||||
|
||||
// InitializeRequest is sent from the client to the server when it first
|
||||
// connects, asking it to begin initialization.
|
||||
type InitializeRequest struct {
|
||||
jsonrpc.Request
|
||||
Params InitializeParams `json:"params"`
|
||||
}
|
||||
|
||||
// InitializeResult is sent after receiving an initialize request from the
|
||||
// client.
|
||||
type InitializeResult struct {
|
||||
jsonrpc.Result
|
||||
// The version of the Model Context Protocol that the server wants to use.
|
||||
// This may not match the version that the client requested. If the client cannot
|
||||
// support this version, it MUST disconnect.
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities ServerCapabilities `json:"capabilities"`
|
||||
ServerInfo Implementation `json:"serverInfo"`
|
||||
// Instructions describing how to use the server and its features.
|
||||
//
|
||||
// This can be used by clients to improve the LLM's understanding of
|
||||
// available tools, resources, etc. It can be thought of like a "hint" to the model.
|
||||
// For example, this information MAY be added to the system prompt.
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
}
|
||||
|
||||
// InitializedNotification is sent from the client to the server after
|
||||
// initialization has finished.
|
||||
type InitializedNotification struct {
|
||||
jsonrpc.Notification
|
||||
}
|
||||
|
||||
// ListChange represents whether the server supports notification for changes to the capabilities.
|
||||
type ListChanged struct {
|
||||
ListChanged *bool `json:"listChanged,omitempty"`
|
||||
}
|
||||
|
||||
// ClientCapabilities represents capabilities a client may support. Known
|
||||
// capabilities are defined here, in this schema, but this is not a closed set: any
|
||||
// client can define its own, additional capabilities.
|
||||
type ClientCapabilities struct {
|
||||
// Experimental, non-standard capabilities that the client supports.
|
||||
Experimental map[string]interface{} `json:"experimental,omitempty"`
|
||||
// Present if the client supports listing roots.
|
||||
Roots *ListChanged `json:"roots,omitempty"`
|
||||
// Present if the client supports sampling from an LLM.
|
||||
Sampling struct{} `json:"sampling,omitempty"`
|
||||
}
|
||||
|
||||
// ServerCapabilities represents capabilities that a server may support. Known
|
||||
// capabilities are defined here, in this schema, but this is not a closed set: any
|
||||
// server can define its own, additional capabilities.
|
||||
type ServerCapabilities struct {
|
||||
Tools *ListChanged `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// Implementation describes the name and version of an MCP implementation.
|
||||
type Implementation struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
141
internal/server/mcp/v20241105/method.go
Normal file
141
internal/server/mcp/v20241105/method.go
Normal file
@@ -0,0 +1,141 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package v20241105
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
|
||||
// ProcessMethod returns a response for the request.
|
||||
func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, body []byte) (any, error) {
|
||||
switch method {
|
||||
case TOOLS_LIST:
|
||||
return toolsListHandler(id, toolset, body)
|
||||
case TOOLS_CALL:
|
||||
return toolsCallHandler(ctx, id, tools, body)
|
||||
default:
|
||||
err := fmt.Errorf("invalid method %s", method)
|
||||
return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
|
||||
func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) (any, error) {
|
||||
var req ListToolsRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools list request: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
result := ListToolsResult{
|
||||
Tools: toolset.McpManifest,
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: result,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// toolsCallHandler generate a response for tools call.
|
||||
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte) (any, error) {
|
||||
// retrieve logger from context
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
var req CallToolRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools call request: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
toolName := req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||
tool, ok := tools[toolName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
|
||||
aMarshal, err := json.Marshal(toolArgument)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to marshal tools argument: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err = util.DecodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
|
||||
err = fmt.Errorf("unable to decode tools argument: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
|
||||
// Since MCP doesn't support auth, an empty map will be use every time.
|
||||
claimsFromAuth := make(map[string]map[string]any)
|
||||
|
||||
params, err := tool.ParseParams(data, claimsFromAuth)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
if !tool.Authorized([]string{}) {
|
||||
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool")
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, params)
|
||||
if err != nil {
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
for _, d := range results {
|
||||
text := TextContent{Type: "text"}
|
||||
dM, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
text.Text = fmt.Sprintf("fail to marshal: %s, result: %s", err, d)
|
||||
} else {
|
||||
text.Text = string(dM)
|
||||
}
|
||||
content = append(content, text)
|
||||
}
|
||||
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: content},
|
||||
}, nil
|
||||
}
|
||||
137
internal/server/mcp/v20241105/types.go
Normal file
137
internal/server/mcp/v20241105/types.go
Normal file
@@ -0,0 +1,137 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package v20241105
|
||||
|
||||
import (
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
// SERVER_NAME is the server name used in Implementation.
|
||||
const SERVER_NAME = "Toolbox"
|
||||
|
||||
// PROTOCOL_VERSION is the version of the MCP protocol in this package.
|
||||
const PROTOCOL_VERSION = "2024-11-05"
|
||||
|
||||
// methods that are supported.
|
||||
const (
|
||||
TOOLS_LIST = "tools/list"
|
||||
TOOLS_CALL = "tools/call"
|
||||
)
|
||||
|
||||
/* Empty result */
|
||||
|
||||
// EmptyResult represents a response that indicates success but carries no data.
|
||||
type EmptyResult jsonrpc.Result
|
||||
|
||||
/* Pagination */
|
||||
|
||||
// Cursor is an opaque token used to represent a cursor for pagination.
|
||||
type Cursor string
|
||||
|
||||
type PaginatedRequest struct {
|
||||
jsonrpc.Request
|
||||
Params struct {
|
||||
// An opaque token representing the current pagination position.
|
||||
// If provided, the server should return results starting after this cursor.
|
||||
Cursor Cursor `json:"cursor,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type PaginatedResult struct {
|
||||
jsonrpc.Result
|
||||
// An opaque token representing the pagination position after the last returned result.
|
||||
// If present, there may be more results available.
|
||||
NextCursor Cursor `json:"nextCursor,omitempty"`
|
||||
}
|
||||
|
||||
/* Tools */
|
||||
|
||||
// Sent from the client to request a list of tools the server has.
|
||||
type ListToolsRequest struct {
|
||||
PaginatedRequest
|
||||
}
|
||||
|
||||
// The server's response to a tools/list request from the client.
|
||||
type ListToolsResult struct {
|
||||
PaginatedResult
|
||||
Tools []tools.McpManifest `json:"tools"`
|
||||
}
|
||||
|
||||
// Used by the client to invoke a tool provided by the server.
|
||||
type CallToolRequest struct {
|
||||
jsonrpc.Request
|
||||
Params struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// The sender or recipient of messages and data in a conversation.
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleUser Role = "user"
|
||||
RoleAssistant Role = "assistant"
|
||||
)
|
||||
|
||||
// Base for objects that include optional annotations for the client.
|
||||
// The client can use annotations to inform how objects are used or displayed
|
||||
type Annotated struct {
|
||||
Annotations *struct {
|
||||
// Describes who the intended customer of this object or data is.
|
||||
// It can include multiple entries to indicate content useful for multiple
|
||||
// audiences (e.g., `["user", "assistant"]`).
|
||||
Audience []Role `json:"audience,omitempty"`
|
||||
// Describes how important this data is for operating the server.
|
||||
//
|
||||
// A value of 1 means "most important," and indicates that the data is
|
||||
// effectively required, while 0 means "least important," and indicates that
|
||||
// the data is entirely optional.
|
||||
//
|
||||
// @TJS-type number
|
||||
// @minimum 0
|
||||
// @maximum 1
|
||||
Priority float64 `json:"priority,omitempty"`
|
||||
} `json:"annotations,omitempty"`
|
||||
}
|
||||
|
||||
// TextContent represents text provided to or from an LLM.
|
||||
type TextContent struct {
|
||||
Annotated
|
||||
Type string `json:"type"`
|
||||
// The text content of the message.
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// The server's response to a tool call.
|
||||
//
|
||||
// Any errors that originate from the tool SHOULD be reported inside the result
|
||||
// object, with `isError` set to true, _not_ as an MCP protocol-level error
|
||||
// response. Otherwise, the LLM would not be able to see that an error occurred
|
||||
// and self-correct.
|
||||
//
|
||||
// However, any errors in _finding_ the tool, an error indicating that the
|
||||
// server does not support tool calls, or any other exceptional conditions,
|
||||
// should be reported as an MCP error response.
|
||||
type CallToolResult struct {
|
||||
jsonrpc.Result
|
||||
// Could be either a TextContent, ImageContent, or EmbeddedResources
|
||||
// For Toolbox, we will only be sending TextContent
|
||||
Content []TextContent `json:"content"`
|
||||
// Whether the tool call ended in an error.
|
||||
// If not set, this is assumed to be false (the call was successful).
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
141
internal/server/mcp/v20250326/method.go
Normal file
141
internal/server/mcp/v20250326/method.go
Normal file
@@ -0,0 +1,141 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package v20250326
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
|
||||
// ProcessMethod returns a response for the request.
|
||||
func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, body []byte) (any, error) {
|
||||
switch method {
|
||||
case TOOLS_LIST:
|
||||
return toolsListHandler(id, toolset, body)
|
||||
case TOOLS_CALL:
|
||||
return toolsCallHandler(ctx, id, tools, body)
|
||||
default:
|
||||
err := fmt.Errorf("invalid method %s", method)
|
||||
return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
|
||||
func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) (any, error) {
|
||||
var req ListToolsRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools list request: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
result := ListToolsResult{
|
||||
Tools: toolset.McpManifest,
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: result,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// toolsCallHandler generate a response for tools call.
|
||||
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte) (any, error) {
|
||||
// retrieve logger from context
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
var req CallToolRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools call request: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
toolName := req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||
tool, ok := tools[toolName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
|
||||
aMarshal, err := json.Marshal(toolArgument)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to marshal tools argument: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err = util.DecodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
|
||||
err = fmt.Errorf("unable to decode tools argument: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
|
||||
// Since MCP doesn't support auth, an empty map will be use every time.
|
||||
claimsFromAuth := make(map[string]map[string]any)
|
||||
|
||||
params, err := tool.ParseParams(data, claimsFromAuth)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
if !tool.Authorized([]string{}) {
|
||||
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool")
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, params)
|
||||
if err != nil {
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
for _, d := range results {
|
||||
text := TextContent{Type: "text"}
|
||||
dM, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
text.Text = fmt.Sprintf("fail to marshal: %s, result: %s", err, d)
|
||||
} else {
|
||||
text.Text = string(dM)
|
||||
}
|
||||
content = append(content, text)
|
||||
}
|
||||
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: content},
|
||||
}, nil
|
||||
}
|
||||
169
internal/server/mcp/v20250326/types.go
Normal file
169
internal/server/mcp/v20250326/types.go
Normal file
@@ -0,0 +1,169 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package v20250326
|
||||
|
||||
import (
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
// SERVER_NAME is the server name used in Implementation.
|
||||
const SERVER_NAME = "Toolbox"
|
||||
|
||||
// PROTOCOL_VERSION is the version of the MCP protocol in this package.
|
||||
const PROTOCOL_VERSION = "2025-03-26"
|
||||
|
||||
// methods that are supported.
|
||||
const (
|
||||
TOOLS_LIST = "tools/list"
|
||||
TOOLS_CALL = "tools/call"
|
||||
)
|
||||
|
||||
/* Empty result */
|
||||
|
||||
// EmptyResult represents a response that indicates success but carries no data.
|
||||
type EmptyResult jsonrpc.Result
|
||||
|
||||
/* Pagination */
|
||||
|
||||
// Cursor is an opaque token used to represent a cursor for pagination.
|
||||
type Cursor string
|
||||
|
||||
type PaginatedRequest struct {
|
||||
jsonrpc.Request
|
||||
Params struct {
|
||||
// An opaque token representing the current pagination position.
|
||||
// If provided, the server should return results starting after this cursor.
|
||||
Cursor Cursor `json:"cursor,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type PaginatedResult struct {
|
||||
jsonrpc.Result
|
||||
// An opaque token representing the pagination position after the last returned result.
|
||||
// If present, there may be more results available.
|
||||
NextCursor Cursor `json:"nextCursor,omitempty"`
|
||||
}
|
||||
|
||||
/* Tools */
|
||||
|
||||
// Sent from the client to request a list of tools the server has.
|
||||
type ListToolsRequest struct {
|
||||
PaginatedRequest
|
||||
}
|
||||
|
||||
// The server's response to a tools/list request from the client.
|
||||
type ListToolsResult struct {
|
||||
PaginatedResult
|
||||
Tools []tools.McpManifest `json:"tools"`
|
||||
}
|
||||
|
||||
// Used by the client to invoke a tool provided by the server.
|
||||
type CallToolRequest struct {
|
||||
jsonrpc.Request
|
||||
Params struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// The sender or recipient of messages and data in a conversation.
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleUser Role = "user"
|
||||
RoleAssistant Role = "assistant"
|
||||
)
|
||||
|
||||
// Base for objects that include optional annotations for the client.
|
||||
// The client can use annotations to inform how objects are used or displayed
|
||||
type Annotated struct {
|
||||
Annotations *struct {
|
||||
// Describes who the intended customer of this object or data is.
|
||||
// It can include multiple entries to indicate content useful for multiple
|
||||
// audiences (e.g., `["user", "assistant"]`).
|
||||
Audience []Role `json:"audience,omitempty"`
|
||||
// Describes how important this data is for operating the server.
|
||||
//
|
||||
// A value of 1 means "most important," and indicates that the data is
|
||||
// effectively required, while 0 means "least important," and indicates that
|
||||
// the data is entirely optional.
|
||||
//
|
||||
// @TJS-type number
|
||||
// @minimum 0
|
||||
// @maximum 1
|
||||
Priority float64 `json:"priority,omitempty"`
|
||||
} `json:"annotations,omitempty"`
|
||||
}
|
||||
|
||||
// TextContent represents text provided to or from an LLM.
|
||||
type TextContent struct {
|
||||
Annotated
|
||||
Type string `json:"type"`
|
||||
// The text content of the message.
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// The server's response to a tool call.
|
||||
//
|
||||
// Any errors that originate from the tool SHOULD be reported inside the result
|
||||
// object, with `isError` set to true, _not_ as an MCP protocol-level error
|
||||
// response. Otherwise, the LLM would not be able to see that an error occurred
|
||||
// and self-correct.
|
||||
//
|
||||
// However, any errors in _finding_ the tool, an error indicating that the
|
||||
// server does not support tool calls, or any other exceptional conditions,
|
||||
// should be reported as an MCP error response.
|
||||
type CallToolResult struct {
|
||||
jsonrpc.Result
|
||||
// Could be either a TextContent, ImageContent, or EmbeddedResources
|
||||
// For Toolbox, we will only be sending TextContent
|
||||
Content []TextContent `json:"content"`
|
||||
// Whether the tool call ended in an error.
|
||||
// If not set, this is assumed to be false (the call was successful).
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
|
||||
// Additional properties describing a Tool to clients.
|
||||
//
|
||||
// NOTE: all properties in ToolAnnotations are **hints**.
|
||||
// They are not guaranteed to provide a faithful description of
|
||||
// tool behavior (including descriptive properties like `title`).
|
||||
//
|
||||
// Clients should never make tool use decisions based on ToolAnnotations
|
||||
// received from untrusted servers.
|
||||
type ToolAnnotations struct {
|
||||
// A human-readable title for the tool.
|
||||
Title string `json:"title,omitempty"`
|
||||
// If true, the tool does not modify its environment.
|
||||
// Default: false
|
||||
ReadOnlyHint bool `json:"readOnlyHint,omitempty"`
|
||||
// If true, the tool may perform destructive updates to its environment.
|
||||
// If false, the tool performs only additive updates.
|
||||
// (This property is meaningful only when `readOnlyHint == false`)
|
||||
// Default: true
|
||||
DestructiveHint bool `json:"destructiveHint,omitempty"`
|
||||
// If true, calling the tool repeatedly with the same arguments
|
||||
// will have no additional effect on the its environment.
|
||||
// (This property is meaningful only when `readOnlyHint == false`)
|
||||
// Default: false
|
||||
IdempotentHint bool `json:"idempotentHint,omitempty"`
|
||||
// If true, this tool may interact with an "open world" of external
|
||||
// entities. If false, the tool's domain of interaction is closed.
|
||||
// For example, the world of a web search tool is open, whereas that
|
||||
// of a memory tool is not.
|
||||
// Default: true
|
||||
OpenWorldHint bool `json:"openWorldHint,omitempty"`
|
||||
}
|
||||
@@ -25,16 +25,17 @@ import (
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const jsonrpcVersion = "2.0"
|
||||
const protocolVersion = "2024-11-05"
|
||||
const protocolVersion20241105 = "2024-11-05"
|
||||
const protocolVersion20250326 = "2025-03-26"
|
||||
const serverName = "Toolbox"
|
||||
|
||||
var tool1InputSchema = map[string]any{
|
||||
@@ -64,7 +65,7 @@ var tool3InputSchema = map[string]any{
|
||||
"required": []any{"my_array"},
|
||||
}
|
||||
|
||||
func TestMcpEndpoint(t *testing.T) {
|
||||
func TestMcpEndpointWithoutInitialized(t *testing.T) {
|
||||
mockTools := []MockTool{tool1, tool2, tool3}
|
||||
toolsMap, toolsets := setUpResources(t, mockTools)
|
||||
r, shutdown := setUpServer(t, "mcp", toolsMap, toolsets)
|
||||
@@ -76,51 +77,20 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
name string
|
||||
url string
|
||||
isErr bool
|
||||
body mcp.JSONRPCRequest
|
||||
body jsonrpc.JSONRPCRequest
|
||||
want map[string]any
|
||||
}{
|
||||
{
|
||||
name: "initialize",
|
||||
url: "/",
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "mcp-initialize",
|
||||
Request: mcp.Request{
|
||||
Method: "initialize",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "mcp-initialize",
|
||||
"result": map[string]any{
|
||||
"protocolVersion": protocolVersion,
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{"listChanged": false},
|
||||
},
|
||||
"serverInfo": map[string]any{"name": serverName, "version": fakeVersionString},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "basic notification",
|
||||
url: "/",
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Request: mcp.Request{
|
||||
Method: "notification",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tools/list",
|
||||
url: "/",
|
||||
body: mcp.JSONRPCRequest{
|
||||
body: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "tools-list",
|
||||
Request: mcp.Request{
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/list",
|
||||
},
|
||||
},
|
||||
isErr: false,
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "tools-list",
|
||||
@@ -143,57 +113,14 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tools/list on tool1_only",
|
||||
url: "/tool1_only",
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "tools-list-tool1",
|
||||
Request: mcp.Request{
|
||||
Method: "tools/list",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "tools-list-tool1",
|
||||
"result": map[string]any{
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"name": "no_params",
|
||||
"inputSchema": tool1InputSchema,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tools/list on invalid tool set",
|
||||
url: "/foo",
|
||||
isErr: true,
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "tools-list-invalid-toolset",
|
||||
Request: mcp.Request{
|
||||
Method: "tools/list",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "tools-list-invalid-toolset",
|
||||
"error": map[string]any{
|
||||
"code": -32600.0,
|
||||
"message": "toolset does not exist",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing method",
|
||||
url: "/",
|
||||
isErr: true,
|
||||
body: mcp.JSONRPCRequest{
|
||||
body: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "missing-method",
|
||||
Request: mcp.Request{},
|
||||
Request: jsonrpc.Request{},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
@@ -204,34 +131,14 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid method",
|
||||
url: "/",
|
||||
isErr: true,
|
||||
body: mcp.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "invalid-method",
|
||||
Request: mcp.Request{
|
||||
Method: "foo",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "invalid-method",
|
||||
"error": map[string]any{
|
||||
"code": -32601.0,
|
||||
"message": "invalid method foo",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid jsonrpc version",
|
||||
url: "/",
|
||||
isErr: true,
|
||||
body: mcp.JSONRPCRequest{
|
||||
body: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "1.0",
|
||||
Id: "invalid-jsonrpc-version",
|
||||
Request: mcp.Request{
|
||||
Request: jsonrpc.Request{
|
||||
Method: "foo",
|
||||
},
|
||||
},
|
||||
@@ -252,7 +159,7 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
t.Fatalf("unexpected error during marshaling of body")
|
||||
}
|
||||
|
||||
resp, body, err := runRequest(ts, http.MethodPost, tc.url, bytes.NewBuffer(reqMarshal))
|
||||
resp, body, err := runRequest(ts, http.MethodPost, tc.url, bytes.NewBuffer(reqMarshal), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
@@ -275,6 +182,378 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func runInitializeLifecycle(t *testing.T, ts *httptest.Server, protocolVersion string, initializeWant map[string]any, idHeader bool) string {
|
||||
initializeRequestBody := map[string]any{
|
||||
"jsonrpc": jsonrpcVersion,
|
||||
"id": "mcp-initialize",
|
||||
"method": "initialize",
|
||||
"params": map[string]any{
|
||||
"protocolVersion": protocolVersion,
|
||||
},
|
||||
}
|
||||
reqMarshal, err := json.Marshal(initializeRequestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during marshaling of body")
|
||||
}
|
||||
|
||||
resp, body, err := runRequest(ts, http.MethodPost, "/", bytes.NewBuffer(reqMarshal), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
|
||||
if contentType := resp.Header.Get("Content-type"); contentType != "application/json" {
|
||||
t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType)
|
||||
}
|
||||
|
||||
sessionId := resp.Header.Get("Mcp-Session-Id")
|
||||
if idHeader && sessionId == "" {
|
||||
t.Fatalf("Mcp-Session-Id header is expected")
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("unexpected error unmarshalling body: %s", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, initializeWant) {
|
||||
t.Fatalf("unexpected response: got %+v, want %+v", got, initializeWant)
|
||||
}
|
||||
|
||||
header := map[string]string{}
|
||||
if sessionId != "" {
|
||||
header["Mcp-Session-Id"] = sessionId
|
||||
}
|
||||
|
||||
initializeNotificationBody := map[string]any{
|
||||
"jsonrpc": jsonrpcVersion,
|
||||
"method": "notifications/initialized",
|
||||
}
|
||||
notiMarshal, err := json.Marshal(initializeNotificationBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during marshaling of notifications body")
|
||||
}
|
||||
|
||||
_, _, err = runRequest(ts, http.MethodPost, "/", bytes.NewBuffer(notiMarshal), header)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
return sessionId
|
||||
}
|
||||
|
||||
func TestMcpEndpoint(t *testing.T) {
|
||||
mockTools := []MockTool{tool1, tool2, tool3}
|
||||
toolsMap, toolsets := setUpResources(t, mockTools)
|
||||
r, shutdown := setUpServer(t, "mcp", toolsMap, toolsets)
|
||||
defer shutdown()
|
||||
ts := runServer(r, false)
|
||||
defer ts.Close()
|
||||
|
||||
versTestCases := []struct {
|
||||
name string
|
||||
protocol string
|
||||
idHeader bool
|
||||
initWant map[string]any
|
||||
}{
|
||||
{
|
||||
name: "verson 2024-11-05",
|
||||
protocol: protocolVersion20241105,
|
||||
idHeader: false,
|
||||
initWant: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "mcp-initialize",
|
||||
"result": map[string]any{
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{"listChanged": false},
|
||||
},
|
||||
"serverInfo": map[string]any{"name": serverName, "version": fakeVersionString},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "verson 2025-03-26",
|
||||
protocol: protocolVersion20250326,
|
||||
idHeader: true,
|
||||
initWant: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "mcp-initialize",
|
||||
"result": map[string]any{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{"listChanged": false},
|
||||
},
|
||||
"serverInfo": map[string]any{"name": serverName, "version": fakeVersionString},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, vtc := range versTestCases {
|
||||
t.Run(vtc.name, func(t *testing.T) {
|
||||
sessionId := runInitializeLifecycle(t, ts, vtc.protocol, vtc.initWant, vtc.idHeader)
|
||||
|
||||
header := map[string]string{}
|
||||
if sessionId != "" {
|
||||
header["Mcp-Session-Id"] = sessionId
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
url string
|
||||
isErr bool
|
||||
body any
|
||||
want map[string]any
|
||||
}{
|
||||
{
|
||||
name: "basic notification",
|
||||
url: "/",
|
||||
body: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Request: jsonrpc.Request{
|
||||
Method: "notification",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tools/list",
|
||||
url: "/",
|
||||
body: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "tools-list",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/list",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "tools-list",
|
||||
"result": map[string]any{
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"name": "no_params",
|
||||
"inputSchema": tool1InputSchema,
|
||||
},
|
||||
map[string]any{
|
||||
"name": "some_params",
|
||||
"inputSchema": tool2InputSchema,
|
||||
},
|
||||
map[string]any{
|
||||
"name": "array_param",
|
||||
"description": "some description",
|
||||
"inputSchema": tool3InputSchema,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tools/list on tool1_only",
|
||||
url: "/tool1_only",
|
||||
body: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "tools-list-tool1",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/list",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "tools-list-tool1",
|
||||
"result": map[string]any{
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"name": "no_params",
|
||||
"inputSchema": tool1InputSchema,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tools/list on invalid tool set",
|
||||
url: "/foo",
|
||||
isErr: true,
|
||||
body: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "tools-list-invalid-toolset",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/list",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "tools-list-invalid-toolset",
|
||||
"error": map[string]any{
|
||||
"code": -32600.0,
|
||||
"message": "toolset does not exist",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing method",
|
||||
url: "/",
|
||||
isErr: true,
|
||||
body: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "missing-method",
|
||||
Request: jsonrpc.Request{},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "missing-method",
|
||||
"error": map[string]any{
|
||||
"code": -32601.0,
|
||||
"message": "method not found",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid method",
|
||||
url: "/",
|
||||
isErr: true,
|
||||
body: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "invalid-method",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "foo",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "invalid-method",
|
||||
"error": map[string]any{
|
||||
"code": -32601.0,
|
||||
"message": "invalid method foo",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid jsonrpc version",
|
||||
url: "/",
|
||||
isErr: true,
|
||||
body: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "1.0",
|
||||
Id: "invalid-jsonrpc-version",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "foo",
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "invalid-jsonrpc-version",
|
||||
"error": map[string]any{
|
||||
"code": -32600.0,
|
||||
"message": "invalid json-rpc version",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "batch requests",
|
||||
url: "/",
|
||||
isErr: true,
|
||||
body: []any{
|
||||
jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "1.0",
|
||||
Id: "batch-requests1",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "foo",
|
||||
},
|
||||
},
|
||||
jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: jsonrpcVersion,
|
||||
Id: "batch-requests2",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/list",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"error": map[string]any{
|
||||
"code": -32600.0,
|
||||
"message": "not supporting batch requests",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
reqMarshal, err := json.Marshal(tc.body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during marshaling of body")
|
||||
}
|
||||
|
||||
if vtc.protocol == protocolVersion20250326 && len(header) == 0 {
|
||||
t.Fatalf("header is missing")
|
||||
}
|
||||
|
||||
resp, body, err := runRequest(ts, http.MethodPost, tc.url, bytes.NewBuffer(reqMarshal), header)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
|
||||
// Notifications don't expect a response.
|
||||
if tc.want != nil {
|
||||
if contentType := resp.Header.Get("Content-type"); contentType != "application/json" {
|
||||
t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("unexpected error unmarshalling body: %s", err)
|
||||
}
|
||||
// for decode failure, a random uuid is generated in server
|
||||
if tc.want["id"] == nil {
|
||||
tc.want["id"] = got["id"]
|
||||
}
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("unexpected response: got %+v, want %+v", got, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteEndpoint(t *testing.T) {
|
||||
toolsMap, toolsets := map[string]tools.Tool{}, map[string]tools.Toolset{}
|
||||
r, shutdown := setUpServer(t, "mcp", toolsMap, toolsets)
|
||||
defer shutdown()
|
||||
ts := runServer(r, false)
|
||||
defer ts.Close()
|
||||
|
||||
resp, _, err := runRequest(ts, http.MethodDelete, "/", nil, nil)
|
||||
if resp.Status != "200 OK" {
|
||||
t.Fatalf("unexpected status: %s", resp.Status)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetEndpoint(t *testing.T) {
|
||||
toolsMap, toolsets := map[string]tools.Tool{}, map[string]tools.Toolset{}
|
||||
r, shutdown := setUpServer(t, "mcp", toolsMap, toolsets)
|
||||
defer shutdown()
|
||||
ts := runServer(r, false)
|
||||
defer ts.Close()
|
||||
|
||||
resp, body, err := runRequest(ts, http.MethodGet, "/", nil, nil)
|
||||
if resp.Status != "405 Method Not Allowed" {
|
||||
t.Fatalf("unexpected status: %s", resp.Status)
|
||||
}
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(body, &got); err != nil {
|
||||
t.Fatalf("unexpected error unmarshalling body: %s", err)
|
||||
}
|
||||
want := "toolbox does not support streaming in streamable HTTP transport"
|
||||
if got["error"] != want {
|
||||
t.Fatalf("unexpected error message: %s", got["error"])
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSseEndpoint(t *testing.T) {
|
||||
r, shutdown := setUpServer(t, "mcp", nil, nil)
|
||||
defer shutdown()
|
||||
@@ -419,10 +698,7 @@ func TestStdioSession(t *testing.T) {
|
||||
t.Fatalf("unable to create custom metrics: %s", err)
|
||||
}
|
||||
|
||||
sseManager := &sseManager{
|
||||
mu: sync.RWMutex{},
|
||||
sseSessions: make(map[string]*sseSession),
|
||||
}
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
server := &Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, sseManager: sseManager, tools: toolsMap, toolsets: toolsets}
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -206,10 +205,7 @@ func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, er
|
||||
addr := net.JoinHostPort(cfg.Address, strconv.Itoa(cfg.Port))
|
||||
srv := &http.Server{Addr: addr, Handler: r}
|
||||
|
||||
sseManager := &sseManager{
|
||||
mu: sync.RWMutex{},
|
||||
sseSessions: make(map[string]*sseSession),
|
||||
}
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
s := &Server{
|
||||
version: cfg.Version,
|
||||
|
||||
@@ -16,13 +16,25 @@ package util
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
)
|
||||
|
||||
// DecodeJSON decodes a given reader into an interface using the json decoder.
|
||||
func DecodeJSON(r io.Reader, v interface{}) error {
|
||||
defer io.Copy(io.Discard, r) //nolint:errcheck
|
||||
d := json.NewDecoder(r)
|
||||
// specify JSON numbers should get parsed to json.Number instead of float64 by default.
|
||||
// This prevents loss between floats/ints.
|
||||
d.UseNumber()
|
||||
return d.Decode(v)
|
||||
}
|
||||
|
||||
var _ yaml.InterfaceUnmarshalerContext = &DelayedUnmarshaler{}
|
||||
|
||||
// DelayedUnmarshaler is struct that saves the provided unmarshal function
|
||||
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
@@ -330,11 +330,17 @@ func getAiNlToolsConfig(sourceConfig map[string]any) map[string]any {
|
||||
}
|
||||
|
||||
func runAiNlMCPToolCallMethod(t *testing.T) {
|
||||
sessionId := tests.RunInitialize(t, "2024-11-05")
|
||||
header := map[string]string{}
|
||||
if sessionId != "" {
|
||||
header["Mcp-Session-Id"] = sessionId
|
||||
}
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody mcp.JSONRPCRequest
|
||||
requestBody jsonrpc.JSONRPCRequest
|
||||
requestHeader map[string]string
|
||||
want string
|
||||
}{
|
||||
@@ -342,10 +348,10 @@ func runAiNlMCPToolCallMethod(t *testing.T) {
|
||||
name: "MCP Invoke my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "my-simple-tool",
|
||||
Request: mcp.Request{
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
@@ -361,10 +367,10 @@ func runAiNlMCPToolCallMethod(t *testing.T) {
|
||||
name: "MCP Invoke invalid tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invalid-tool",
|
||||
Request: mcp.Request{
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
@@ -378,10 +384,10 @@ func runAiNlMCPToolCallMethod(t *testing.T) {
|
||||
name: "MCP Invoke my-auth-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invoke-without-parameter",
|
||||
Request: mcp.Request{
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
@@ -404,7 +410,7 @@ func runAiNlMCPToolCallMethod(t *testing.T) {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
for k, v := range header {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
|
||||
125
tests/tool.go
125
tests/tool.go
@@ -24,7 +24,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
)
|
||||
|
||||
// RunToolGet runs the tool get endpoint
|
||||
@@ -550,13 +550,65 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement string, sele
|
||||
}
|
||||
}
|
||||
|
||||
// RunInitialize runs the initialize lifecycle for mcp to set up client-server connection
|
||||
func RunInitialize(t *testing.T, protocolVersion string) string {
|
||||
url := "http://127.0.0.1:5000/mcp"
|
||||
|
||||
initializeRequestBody := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "mcp-initialize",
|
||||
"method": "initialize",
|
||||
"params": map[string]any{
|
||||
"protocolVersion": protocolVersion,
|
||||
},
|
||||
}
|
||||
reqMarshal, err := json.Marshal(initializeRequestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during marshaling of body")
|
||||
}
|
||||
|
||||
resp, _ := runRequest(t, http.MethodPost, url, bytes.NewBuffer(reqMarshal), nil)
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("response status code is not 200")
|
||||
}
|
||||
|
||||
if contentType := resp.Header.Get("Content-type"); contentType != "application/json" {
|
||||
t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType)
|
||||
}
|
||||
|
||||
sessionId := resp.Header.Get("Mcp-Session-Id")
|
||||
|
||||
header := map[string]string{}
|
||||
if sessionId != "" {
|
||||
header["Mcp-Session-Id"] = sessionId
|
||||
}
|
||||
|
||||
initializeNotificationBody := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized",
|
||||
}
|
||||
notiMarshal, err := json.Marshal(initializeNotificationBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during marshaling of notifications body")
|
||||
}
|
||||
|
||||
_, _ = runRequest(t, http.MethodPost, url, bytes.NewBuffer(notiMarshal), header)
|
||||
return sessionId
|
||||
}
|
||||
|
||||
// RunMCPToolCallMethod runs the tool/call for mcp endpoint
|
||||
func RunMCPToolCallMethod(t *testing.T, invokeParamWant, fail_invocation_want string) {
|
||||
sessionId := RunInitialize(t, "2024-11-05")
|
||||
header := map[string]string{}
|
||||
if sessionId != "" {
|
||||
header["Mcp-Session-Id"] = sessionId
|
||||
}
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody mcp.JSONRPCRequest
|
||||
requestBody jsonrpc.JSONRPCRequest
|
||||
requestHeader map[string]string
|
||||
want string
|
||||
}{
|
||||
@@ -564,10 +616,10 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, fail_invocation_want st
|
||||
name: "MCP Invoke my-param-tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "my-param-tool",
|
||||
Request: mcp.Request{
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
@@ -584,10 +636,10 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, fail_invocation_want st
|
||||
name: "MCP Invoke invalid tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invalid-tool",
|
||||
Request: mcp.Request{
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
@@ -601,10 +653,10 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, fail_invocation_want st
|
||||
name: "MCP Invoke my-param-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invoke-without-parameter",
|
||||
Request: mcp.Request{
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
@@ -618,10 +670,10 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, fail_invocation_want st
|
||||
name: "MCP Invoke my-param-tool with insufficient parameters",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invoke-insufficient-parameter",
|
||||
Request: mcp.Request{
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
@@ -635,10 +687,10 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, fail_invocation_want st
|
||||
name: "MCP Invoke my-auth-required-tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invoke my-auth-required-tool",
|
||||
Request: mcp.Request{
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
@@ -652,10 +704,10 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, fail_invocation_want st
|
||||
name: "MCP Invoke my-fail-tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: mcp.JSONRPCRequest{
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invoke-fail-tool",
|
||||
Request: mcp.Request{
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
@@ -672,24 +724,8 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, fail_invocation_want st
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during marshaling of request body")
|
||||
}
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, bytes.NewBuffer(reqMarshal))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to read request body: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
_, respBody := runRequest(t, http.MethodPost, tc.api, bytes.NewBuffer(reqMarshal), header)
|
||||
got := string(bytes.TrimSpace(respBody))
|
||||
|
||||
if !strings.Contains(got, tc.want) {
|
||||
@@ -698,3 +734,28 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, fail_invocation_want st
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runRequest(t *testing.T, method, url string, body io.Reader, header map[string]string) (*http.Response, []byte) {
|
||||
// Send request
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range header {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to read request body: %s", err)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
return resp, respBody
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user