diff --git a/internal/server/mcp.go b/internal/server/mcp.go index 3adac31ab7..65ace06d66 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -36,9 +36,11 @@ import ( v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105" v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326" "github.com/googleapis/genai-toolbox/internal/util" + "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" ) type sseSession struct { @@ -116,6 +118,55 @@ type stdioSession struct { writer io.Writer } +// traceContextCarrier implements propagation.TextMapCarrier for extracting trace context from _meta +type traceContextCarrier map[string]string + +func (c traceContextCarrier) Get(key string) string { + return c[key] +} + +func (c traceContextCarrier) Set(key, value string) { + c[key] = value +} + +func (c traceContextCarrier) Keys() []string { + keys := make([]string, 0, len(c)) + for k := range c { + keys = append(keys, k) + } + return keys +} + +// extractTraceContext extracts W3C Trace Context (traceparent and optional tracestate) from params._meta of the JSON body and returns a context carrying it; ctx is returned unchanged when the body fails to unmarshal or traceparent is empty. +func extractTraceContext(ctx context.Context, body []byte) context.Context { + // Try to parse the request to extract _meta + var req struct { + Params struct { + Meta struct { + Traceparent string `json:"traceparent,omitempty"` + Tracestate string `json:"tracestate,omitempty"` + } `json:"_meta,omitempty"` + } `json:"params,omitempty"` + } + + if err := json.Unmarshal(body, &req); err != nil { + return ctx + } + + // If traceparent is present, extract the context + if req.Params.Meta.Traceparent != "" { + carrier := traceContextCarrier{ + "traceparent": req.Params.Meta.Traceparent, + } + if req.Params.Meta.Tracestate != "" { + carrier["tracestate"] = req.Params.Meta.Tracestate + } + return otel.GetTextMapPropagator().Extract(ctx, carrier) + } + + return ctx +} + func NewStdioSession(s *Server, stdin io.Reader, 
stdout io.Writer) *stdioSession { stdioSession := &stdioSession{ server: s, @@ -142,18 +193,29 @@ func (s *stdioSession) readInputStream(ctx context.Context) error { } return err } - v, res, err := processMcpMessage(ctx, []byte(line), s.server, s.protocol, "", "", nil) + // Pull the client's trace context out of params._meta so the transport span becomes a child of the client span + msgCtx := extractTraceContext(ctx, []byte(line)) + + // Create span for STDIO transport. NOTE(review): the span is ended with defer below; this hunk appears to sit inside the per-line read loop (loop header not visible here), in which case every End() would be postponed until readInputStream returns -- confirm. + msgCtx, span := s.server.instrumentation.Tracer.Start(msgCtx, "toolbox/server/mcp/stdio", + trace.WithSpanKind(trace.SpanKindServer), + ) + defer span.End() + + v, res, err := processMcpMessage(msgCtx, []byte(line), s.server, s.protocol, "", "", nil, "") if err != nil { // errors during the processing of message will generate a valid MCP Error response. // server can continue to run. - s.server.logger.ErrorContext(ctx, err.Error()) + s.server.logger.ErrorContext(msgCtx, err.Error()) + span.SetStatus(codes.Error, err.Error()) } + if v != "" { s.protocol = v } // no responses for notifications if res != nil { - if err = s.write(ctx, res); err != nil { + if err = s.write(msgCtx, res); err != nil { return err } } @@ -239,7 +301,9 @@ func mcpRouter(s *Server) (chi.Router, error) { // sseHandler handles sse initialization and message. 
func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) { - ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp/sse") + ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp/sse", + trace.WithSpanKind(trace.SpanKindServer), + ) r = r.WithContext(ctx) sessionId := uuid.New().String() @@ -335,9 +399,27 @@ func methodNotAllowed(s *Server, w http.ResponseWriter, r *http.Request) { func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp") + ctx := r.Context() + ctx = util.WithLogger(ctx, s.logger) + + // Read body first so we can extract trace context + body, err := io.ReadAll(r.Body) + if err != nil { + // Generate a new uuid for the error response if the body cannot be read + id := uuid.New().String() + s.logger.DebugContext(ctx, err.Error()) + render.JSON(w, r, jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil)) + return + } + + // Pull the client's trace context out of params._meta so the transport span becomes a child of the client span + ctx = extractTraceContext(ctx, body) + + // Create span for HTTP transport + ctx, span := s.instrumentation.Tracer.Start(ctx, "toolbox/server/mcp/http", + trace.WithSpanKind(trace.SpanKindServer), + ) r = r.WithContext(ctx) - ctx = util.WithLogger(r.Context(), s.logger) var sessionId, protocolVersion string var session *sseSession @@ -379,7 +461,6 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName)) span.SetAttributes(attribute.String("toolset_name", toolsetName)) - var err error defer func() { if err != nil { span.SetStatus(codes.Error, err.Error()) @@ -398,17 +479,9 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { ) }() - // Read and returns a body from io.Reader - body, err := io.ReadAll(r.Body) - if err != nil { - // Generate a new uuid if unable to decode - id := 
uuid.New().String() - s.logger.DebugContext(ctx, err.Error()) - render.JSON(w, r, jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil)) - return - } + networkProtocolVersion := fmt.Sprintf("%d.%d", r.ProtoMajor, r.ProtoMinor) - v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, promptsetName, r.Header) + v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, promptsetName, r.Header, networkProtocolVersion) if err != nil { s.logger.DebugContext(ctx, fmt.Errorf("error processing message: %w", err).Error()) } @@ -458,7 +531,7 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { } // processMcpMessage process the messages received from clients -func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string, promptsetName string, header http.Header) (string, any, error) { +func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string, promptsetName string, header http.Header, networkProtocolVersion string) (string, any, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return "", jsonrpc.NewError("", jsonrpc.INTERNAL_ERROR, err.Error(), nil), err @@ -494,31 +567,95 @@ func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVers return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } + // Create method-specific span, named after the request's method, with semantic conventions + // Note: Trace context is already extracted and set in ctx by the caller + ctx, span := s.instrumentation.Tracer.Start(ctx, baseMessage.Method, + trace.WithSpanKind(trace.SpanKindServer), + ) + defer span.End() + + // Determine network transport and protocol based on header presence + networkTransport := "pipe" // default for stdio + networkProtocolName := "stdio" + if header != nil { + networkTransport = "tcp" // HTTP/SSE transport + networkProtocolName = "http" + } + + // Set required 
semantic attributes for span according to OTEL MCP semconv + // ref: https://opentelemetry.io/docs/specs/semconv/gen-ai/mcp/#server + span.SetAttributes( + attribute.String("mcp.method.name", baseMessage.Method), + attribute.String("network.transport", networkTransport), + attribute.String("network.protocol.name", networkProtocolName), + ) + + // Set network protocol version if available + if networkProtocolVersion != "" { + span.SetAttributes(attribute.String("network.protocol.version", networkProtocolVersion)) + } + + // Set MCP protocol version if available + if protocolVersion != "" { + span.SetAttributes(attribute.String("mcp.protocol.version", protocolVersion)) + } + + // Set request ID + if baseMessage.Id != nil { + span.SetAttributes(attribute.String("jsonrpc.request.id", fmt.Sprintf("%v", baseMessage.Id))) + } + + // Set toolset name. NOTE(review): httpHandler's span records the same value under the key "toolset_name"; confirm the differing key "toolset.name" is intentional. + span.SetAttributes(attribute.String("toolset.name", toolsetName)) + // Check if message is a notification if baseMessage.Id == nil { err := mcp.NotificationHandler(ctx, body) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + } return "", nil, err } + // Process the method switch baseMessage.Method { case mcputil.INITIALIZE: - res, v, err := mcp.InitializeResponse(ctx, baseMessage.Id, body, s.version) + result, version, err := mcp.InitializeResponse(ctx, baseMessage.Id, body, s.version) if err != nil { - return "", res, err + span.SetStatus(codes.Error, err.Error()) + if rpcErr, ok := result.(jsonrpc.JSONRPCError); ok { + span.SetAttributes(attribute.String("error.type", rpcErr.Error.String())) + } + return "", result, err } - return v, res, err + span.SetAttributes(attribute.String("mcp.protocol.version", version)) + return version, result, err default: toolset, ok := s.ResourceMgr.GetToolset(toolsetName) if !ok { - err = fmt.Errorf("toolset does not exist") - return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err + err := fmt.Errorf("toolset does not exist") + rpcErr := 
jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil) + span.SetStatus(codes.Error, err.Error()) + span.SetAttributes(attribute.String("error.type", rpcErr.Error.String())) + return "", rpcErr, err } promptset, ok := s.ResourceMgr.GetPromptset(promptsetName) if !ok { - err = fmt.Errorf("promptset does not exist") - return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err + err := fmt.Errorf("promptset does not exist") + rpcErr := jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil) + span.SetStatus(codes.Error, err.Error()) + span.SetAttributes(attribute.String("error.type", rpcErr.Error.String())) + return "", rpcErr, err } - res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, promptset, s.ResourceMgr, body, header) - return "", res, err + result, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, promptset, s.ResourceMgr, body, header) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + // Record the JSON-RPC error code and the error.type label derived from it + if rpcErr, ok := result.(jsonrpc.JSONRPCError); ok { + span.SetAttributes(attribute.Int("jsonrpc.error.code", rpcErr.Error.Code)) + span.SetAttributes(attribute.String("error.type", rpcErr.Error.String())) + } + } + return "", result, err } } diff --git a/internal/server/mcp/jsonrpc/jsonrpc.go b/internal/server/mcp/jsonrpc/jsonrpc.go index 7099ea8a63..8a4aaaf15b 100644 --- a/internal/server/mcp/jsonrpc/jsonrpc.go +++ b/internal/server/mcp/jsonrpc/jsonrpc.go @@ -45,6 +45,9 @@ type Request struct { // notifications. The receiver is not obligated to provide these // notifications. 
ProgressToken ProgressToken `json:"progressToken,omitempty"` + // W3C Trace Context fields for distributed tracing + Traceparent string `json:"traceparent,omitempty"` + Tracestate string `json:"tracestate,omitempty"` } `json:"_meta,omitempty"` } `json:"params,omitempty"` } @@ -97,6 +100,24 @@ type Error struct { Data interface{} `json:"data,omitempty"` } +// String returns a short label for the error's JSON-RPC code (e.g. "invalid_params"); codes without a dedicated case yield "jsonrpc_error". NOTE(review): a String() method makes Error satisfy fmt.Stringer, so fmt verbs such as %v will print this label for Error values -- confirm that is intended. +func (e Error) String() string { + switch e.Code { + case METHOD_NOT_FOUND: + return "method_not_found" + case INVALID_PARAMS: + return "invalid_params" + case INTERNAL_ERROR: + return "internal_error" + case PARSE_ERROR: + return "parse_error" + case INVALID_REQUEST: + return "invalid_request" + default: + return "jsonrpc_error" + } +} + // JSONRPCError represents a non-successful (error) response to a request. type JSONRPCError struct { Jsonrpc string `json:"jsonrpc"` diff --git a/internal/server/mcp/v20241105/method.go b/internal/server/mcp/v20241105/method.go index 0dd6943734..4684f4687c 100644 --- a/internal/server/mcp/v20241105/method.go +++ b/internal/server/mcp/v20241105/method.go @@ -28,6 +28,8 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) // ProcessMethod returns a response for the request. 
@@ -101,6 +103,14 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re toolName := req.Params.Name toolArgument := req.Params.Arguments logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) + + // Update span name and set gen_ai attributes + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName)) + span.SetAttributes( + attribute.String("gen_ai.tool.name", toolName), + attribute.String("gen_ai.operation.name", "execute_tool"), + ) tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -310,6 +320,11 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r promptName := req.Params.Name logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName)) + + // Rename the span carried by ctx to PROMPTS_GET plus the prompt name and record gen_ai.prompt.name + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) + span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/server/mcp/v20250326/method.go b/internal/server/mcp/v20250326/method.go index 22183d45d9..24c61fd617 100644 --- a/internal/server/mcp/v20250326/method.go +++ b/internal/server/mcp/v20250326/method.go @@ -28,6 +28,8 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) // ProcessMethod returns a response for the request. 
@@ -101,6 +103,15 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re toolName := req.Params.Name toolArgument := req.Params.Arguments logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) + + // Update span name and set gen_ai attributes + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName)) + span.SetAttributes( + attribute.String("gen_ai.tool.name", toolName), + attribute.String("gen_ai.operation.name", "execute_tool"), + ) + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -309,6 +320,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r promptName := req.Params.Name logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName)) + + // Rename the span carried by ctx to PROMPTS_GET plus the prompt name and record gen_ai.prompt.name + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) + span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/server/mcp/v20250618/method.go b/internal/server/mcp/v20250618/method.go index 24312d2da9..b6cb45059b 100644 --- a/internal/server/mcp/v20250618/method.go +++ b/internal/server/mcp/v20250618/method.go @@ -28,6 +28,8 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) // ProcessMethod returns a response for the request. 
@@ -94,6 +96,15 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re toolName := req.Params.Name toolArgument := req.Params.Arguments logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) + + // Update span name and set gen_ai attributes + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName)) + span.SetAttributes( + attribute.String("gen_ai.tool.name", toolName), + attribute.String("gen_ai.operation.name", "execute_tool"), + ) + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -303,6 +314,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r promptName := req.Params.Name logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName)) + + // Rename the span carried by ctx to PROMPTS_GET plus the prompt name and record gen_ai.prompt.name + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) + span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) diff --git a/internal/server/mcp/v20251125/method.go b/internal/server/mcp/v20251125/method.go index 408fd0303c..2d59554c55 100644 --- a/internal/server/mcp/v20251125/method.go +++ b/internal/server/mcp/v20251125/method.go @@ -28,6 +28,8 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" ) // ProcessMethod returns a response for the request. 
@@ -94,6 +96,15 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re toolName := req.Params.Name toolArgument := req.Params.Arguments logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) + + // Update span name and set gen_ai attributes + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName)) + span.SetAttributes( + attribute.String("gen_ai.tool.name", toolName), + attribute.String("gen_ai.operation.name", "execute_tool"), + ) + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) @@ -303,6 +314,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r promptName := req.Params.Name logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName)) + + // Rename the span carried by ctx to PROMPTS_GET plus the prompt name and record gen_ai.prompt.name + span := trace.SpanFromContext(ctx) + span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName)) + span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName)) + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName)