removing ssz-only flag ( reverting feature) and fix accept header middleware (#15433)

* removing ssz-only flag

* gaz

* reverting other uses of sszonly

* gaz

* adding kasey and radek's suggestions

* update changelog

* adding test

* radek advice with new headers and tests

* adding logs and fixing comments

* adding logs and fixing comments

* gaz

* Update validator/client/beacon-api/rest_handler_client.go

Co-authored-by: Radosław Kapka <rkapka@wp.pl>

* Update api/apiutil/header.go

Co-authored-by: Radosław Kapka <rkapka@wp.pl>

* Update api/apiutil/header.go

Co-authored-by: Radosław Kapka <rkapka@wp.pl>

* radek's comments

* adding another failing case based on radek's suggestion

* another unit test

---------

Co-authored-by: Radosław Kapka <rkapka@wp.pl>
This commit is contained in:
james-prysm
2025-07-22 11:06:51 -05:00
committed by GitHub
parent c21fae239f
commit 77958022e7
17 changed files with 394 additions and 152 deletions

View File

@@ -2,18 +2,28 @@ load("@prysm//tools/go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
srcs = ["common.go"],
srcs = [
"common.go",
"header.go",
],
importpath = "github.com/OffchainLabs/prysm/v6/api/apiutil",
visibility = ["//visibility:public"],
deps = ["//consensus-types/primitives:go_default_library"],
deps = [
"//consensus-types/primitives:go_default_library",
"@com_github_sirupsen_logrus//:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = ["common_test.go"],
srcs = [
"common_test.go",
"header_test.go",
],
embed = [":go_default_library"],
deps = [
"//consensus-types/primitives:go_default_library",
"//testing/assert:go_default_library",
"//testing/require:go_default_library",
],
)

122
api/apiutil/header.go Normal file
View File

@@ -0,0 +1,122 @@
package apiutil
import (
"mime"
"sort"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
)
type mediaRange struct {
mt string // canonicalised mediatype, e.g. "application/json"
q float64 // quality factor (01)
raw string // original string useful for logging/debugging
spec int // 2=exact, 1=type/*, 0=*/*
}
func parseMediaRange(field string) (mediaRange, bool) {
field = strings.TrimSpace(field)
mt, params, err := mime.ParseMediaType(field)
if err != nil {
log.WithError(err).Debug("Failed to parse header field")
return mediaRange{}, false
}
r := mediaRange{mt: mt, q: 1, spec: 2, raw: field}
if qs, ok := params["q"]; ok {
v, err := strconv.ParseFloat(qs, 64)
if err != nil || v < 0 || v > 1 {
log.WithField("q", qs).Debug("Invalid quality factor (01)")
return mediaRange{}, false // skip invalid entry
}
r.q = v
}
switch {
case mt == "*/*":
r.spec = 0
case strings.HasSuffix(mt, "/*"):
r.spec = 1
}
return r, true
}
func hasExplicitQ(r mediaRange) bool {
return strings.Contains(strings.ToLower(r.raw), ";q=")
}
// ParseAccept returns media ranges sorted by q (desc) then specificity.
func ParseAccept(header string) []mediaRange {
if header == "" {
return []mediaRange{{mt: "*/*", q: 1, spec: 0, raw: "*/*"}}
}
var out []mediaRange
for _, field := range strings.Split(header, ",") {
if r, ok := parseMediaRange(field); ok {
out = append(out, r)
}
}
sort.SliceStable(out, func(i, j int) bool {
ei, ej := hasExplicitQ(out[i]), hasExplicitQ(out[j])
if ei != ej {
return ei // explicit beats implicit
}
if out[i].q != out[j].q {
return out[i].q > out[j].q
}
return out[i].spec > out[j].spec
})
return out
}
// Matches reports whether content type is acceptable per the header.
func Matches(header, ct string) bool {
for _, r := range ParseAccept(header) {
switch {
case r.q == 0:
continue
case r.mt == "*/*":
return true
case strings.HasSuffix(r.mt, "/*"):
if strings.HasPrefix(ct, r.mt[:len(r.mt)-1]) {
return true
}
case r.mt == ct:
return true
}
}
return false
}
// Negotiate selects the best server type according to the header.
// Returns the chosen type and true, or "", false when nothing matches.
func Negotiate(header string, serverTypes []string) (string, bool) {
for _, r := range ParseAccept(header) {
if r.q == 0 {
continue
}
for _, s := range serverTypes {
if Matches(r.mt, s) {
return s, true
}
}
}
return "", false
}
// PrimaryAcceptMatches only checks if the first accept matches
func PrimaryAcceptMatches(header, produced string) bool {
for _, r := range ParseAccept(header) {
if r.q == 0 {
continue // explicitly unacceptable skip
}
return Matches(r.mt, produced)
}
return false
}

174
api/apiutil/header_test.go Normal file
View File

@@ -0,0 +1,174 @@
package apiutil
import (
"testing"
"github.com/OffchainLabs/prysm/v6/testing/require"
)
func TestParseAccept(t *testing.T) {
type want struct {
mt string
q float64
spec int
}
cases := []struct {
name string
header string
want []want
}{
{
name: "empty header becomes */*;q=1",
header: "",
want: []want{{mt: "*/*", q: 1, spec: 0}},
},
{
name: "quality ordering then specificity",
header: "application/json;q=0.2, */*;q=0.1, application/xml;q=0.5, text/*;q=0.5",
want: []want{
{mt: "application/xml", q: 0.5, spec: 2},
{mt: "text/*", q: 0.5, spec: 1},
{mt: "application/json", q: 0.2, spec: 2},
{mt: "*/*", q: 0.1, spec: 0},
},
},
{
name: "invalid pieces are skipped",
header: "text/plain; q=boom, application/json",
want: []want{{mt: "application/json", q: 1, spec: 2}},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := ParseAccept(tc.header)
gotProjected := make([]want, len(got))
for i, g := range got {
gotProjected[i] = want{mt: g.mt, q: g.q, spec: g.spec}
}
require.DeepEqual(t, gotProjected, tc.want)
})
}
}
func TestMatches(t *testing.T) {
cases := []struct {
name string
accept string
ct string
matches bool
}{
{"exact match", "application/json", "application/json", true},
{"type wildcard", "application/*;q=0.8", "application/xml", true},
{"global wildcard", "*/*;q=0.1", "image/png", true},
{"explicitly unacceptable (q=0)", "text/*;q=0", "text/plain", false},
{"no match", "image/png", "application/json", false},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := Matches(tc.accept, tc.ct)
require.Equal(t, tc.matches, got)
})
}
}
func TestNegotiate(t *testing.T) {
cases := []struct {
name string
accept string
serverTypes []string
wantType string
ok bool
}{
{
name: "highest quality wins",
accept: "application/json;q=0.8,application/xml;q=0.9",
serverTypes: []string{"application/json", "application/xml"},
wantType: "application/xml",
ok: true,
},
{
name: "wildcard matches first server type",
accept: "*/*;q=0.5",
serverTypes: []string{"application/octet-stream", "application/json"},
wantType: "application/octet-stream",
ok: true,
},
{
name: "no acceptable type",
accept: "image/png",
serverTypes: []string{"application/json"},
wantType: "",
ok: false,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got, ok := Negotiate(tc.accept, tc.serverTypes)
require.Equal(t, tc.ok, ok)
require.Equal(t, tc.wantType, got)
})
}
}
func TestPrimaryAcceptMatches(t *testing.T) {
tests := []struct {
name string
accept string
produced string
expect bool
}{
{
name: "prefers json",
accept: "application/json;q=0.9,application/xml",
produced: "application/json",
expect: true,
},
{
name: "wildcard application beats other wildcard",
accept: "application/*;q=0.2,*/*;q=0.1",
produced: "application/xml",
expect: true,
},
{
name: "json wins",
accept: "application/xml;q=0.8,application/json;q=0.9",
produced: "application/json",
expect: true,
},
{
name: "json loses",
accept: "application/xml;q=0.8,application/json;q=0.9,application/octet-stream;q=0.99",
produced: "application/json",
expect: false,
},
{
name: "json wins with non q option",
accept: "application/xml;q=0.8,image/png,application/json;q=0.9",
produced: "application/json",
expect: true,
},
{
name: "json not primary",
accept: "image/png,application/json",
produced: "application/json",
expect: false,
},
{
name: "absent header",
accept: "",
produced: "text/plain",
expect: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := PrimaryAcceptMatches(tc.accept, tc.produced)
require.Equal(t, got, tc.expect)
})
}
}

View File

@@ -10,6 +10,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//api:go_default_library",
"//api/apiutil:go_default_library",
"@com_github_rs_cors//:go_default_library",
"@com_github_sirupsen_logrus//:go_default_library",
],

View File

@@ -7,6 +7,7 @@ import (
"strings"
"github.com/OffchainLabs/prysm/v6/api"
"github.com/OffchainLabs/prysm/v6/api/apiutil"
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
)
@@ -74,42 +75,10 @@ func ContentTypeHandler(acceptedMediaTypes []string) Middleware {
func AcceptHeaderHandler(serverAcceptedTypes []string) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
acceptHeader := r.Header.Get("Accept")
// header is optional and should skip if not provided
if acceptHeader == "" {
next.ServeHTTP(w, r)
if _, ok := apiutil.Negotiate(r.Header.Get("Accept"), serverAcceptedTypes); !ok {
http.Error(w, "Not Acceptable", http.StatusNotAcceptable)
return
}
accepted := false
acceptTypes := strings.Split(acceptHeader, ",")
// follows rules defined in https://datatracker.ietf.org/doc/html/rfc2616#section-14.1
for _, acceptType := range acceptTypes {
acceptType = strings.TrimSpace(acceptType)
if acceptType == "*/*" {
accepted = true
break
}
for _, serverAcceptedType := range serverAcceptedTypes {
if strings.HasPrefix(acceptType, serverAcceptedType) {
accepted = true
break
}
if acceptType != "/*" && strings.HasSuffix(acceptType, "/*") && strings.HasPrefix(serverAcceptedType, acceptType[:len(acceptType)-2]) {
accepted = true
break
}
}
if accepted {
break
}
}
if !accepted {
http.Error(w, fmt.Sprintf("Not Acceptable: %s", acceptHeader), http.StatusNotAcceptable)
return
}
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,7 @@
### Removed
- Partially reverting pr #15390 removing the `ssz-only` debug flag until there is a real usecase for the flag
### Added
- Added new PRYSM_API_OVERRIDE_ACCEPT environment variable to override ssz accept header as a replacement to flag

View File

@@ -51,7 +51,6 @@ type Flags struct {
EnableExperimentalAttestationPool bool // EnableExperimentalAttestationPool enables an experimental attestation pool design.
DisableDutiesV2 bool // DisableDutiesV2 sets validator client to use the get Duties endpoint
EnableWeb bool // EnableWeb enables the webui on the validator client
SSZOnly bool // SSZOnly forces the validator client to use SSZ for communication with the beacon node when REST mode is enabled (useful for debugging)
// Logging related toggles.
DisableGRPCConnectionLogs bool // Disables logging when a new grpc client has connected.
EnableFullSSZDataLogging bool // Enables logging for full ssz data on rejected gossip messages
@@ -340,10 +339,6 @@ func ConfigureValidator(ctx *cli.Context) error {
logEnabled(EnableWebFlag)
cfg.EnableWeb = true
}
if ctx.Bool(SSZOnly.Name) {
logEnabled(SSZOnly)
cfg.SSZOnly = true
}
cfg.KeystoreImportDebounceInterval = ctx.Duration(dynamicKeyReloadDebounceInterval.Name)
Init(cfg)

View File

@@ -197,12 +197,6 @@ var (
Usage: "(Work in progress): Enables the web portal for the validator client.",
Value: false,
}
// SSZOnly forces the validator client to use SSZ for communication with the beacon node when REST mode is enabled
SSZOnly = &cli.BoolFlag{
Name: "ssz-only",
Usage: "(debug): Forces the validator client to use SSZ for communication with the beacon node when REST mode is enabled",
}
)
// devModeFlags holds list of flags that are set when development mode is on.
@@ -225,7 +219,6 @@ var ValidatorFlags = append(deprecatedFlags, []cli.Flag{
EnableBeaconRESTApi,
DisableDutiesV2,
EnableWebFlag,
SSZOnly,
}...)
// E2EValidatorFlags contains a list of the validator feature flags to be tested in E2E.

View File

@@ -4,6 +4,10 @@ import (
"testing"
)
const (
EnvNameOverrideAccept = "PRYSM_API_OVERRIDE_ACCEPT"
)
// SetupTestConfigCleanup preserves configurations allowing to modify them within tests without any
// restrictions, everything is restored after the test.
func SetupTestConfigCleanup(t testing.TB) {

View File

@@ -248,9 +248,6 @@ func (v *ValidatorNode) Start(ctx context.Context) error {
args = append(args,
fmt.Sprintf("--%s=http://localhost:%d", flags.BeaconRESTApiProviderFlag.Name, beaconRestApiPort),
fmt.Sprintf("--%s", features.EnableBeaconRESTApi.Name))
if v.config.UseSSZOnly {
args = append(args, fmt.Sprintf("--%s", features.SSZOnly.Name))
}
}
// Only apply e2e flags to the current branch. New flags may not exist in previous release.

View File

@@ -25,14 +25,14 @@ func TestEndToEnd_MinimalConfig_Web3Signer_PersistentKeys(t *testing.T) {
e2eMinimal(t, types.InitForkCfg(version.Bellatrix, version.Electra, params.E2ETestConfig()), types.WithRemoteSignerAndPersistentKeysFile()).run()
}
func TestEndToEnd_MinimalConfig_ValidatorRESTApi(t *testing.T) {
e2eMinimal(t, types.InitForkCfg(version.Bellatrix, version.Electra, params.E2ETestConfig()), types.WithCheckpointSync(), types.WithValidatorRESTApi()).run()
}
func TestEndToEnd_MinimalConfig_ValidatorRESTApi_SSZ(t *testing.T) {
e2eMinimal(t, types.InitForkCfg(version.Bellatrix, version.Electra, params.E2ETestConfig()), types.WithCheckpointSync(), types.WithValidatorRESTApi(), types.WithSSZOnly()).run()
}
func TestEndToEnd_MinimalConfig_ValidatorRESTApi(t *testing.T) {
e2eMinimal(t, types.InitForkCfg(version.Bellatrix, version.Electra, params.E2ETestConfig()), types.WithCheckpointSync(), types.WithValidatorRESTApi()).run()
}
func TestEndToEnd_ScenarioRun_EEOffline(t *testing.T) {
t.Skip("TODO(#10242) Prysm is current unable to handle an offline e2e")
cfg := types.InitForkCfg(version.Bellatrix, version.Deneb, params.E2ETestConfig())

View File

@@ -11,9 +11,11 @@ go_library(
importpath = "github.com/OffchainLabs/prysm/v6/testing/endtoend/types",
visibility = ["//testing/endtoend:__subpackages__"],
deps = [
"//api:go_default_library",
"//config/params:go_default_library",
"//consensus-types/primitives:go_default_library",
"//runtime/version:go_default_library",
"@com_github_sirupsen_logrus//:go_default_library",
"@org_golang_google_grpc//:go_default_library",
],
)

View File

@@ -6,9 +6,11 @@ import (
"context"
"os"
"github.com/OffchainLabs/prysm/v6/api"
"github.com/OffchainLabs/prysm/v6/config/params"
"github.com/OffchainLabs/prysm/v6/consensus-types/primitives"
"github.com/OffchainLabs/prysm/v6/runtime/version"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
)
@@ -51,18 +53,20 @@ func WithValidatorRESTApi() E2EConfigOpt {
}
}
func WithSSZOnly() E2EConfigOpt {
return func(cfg *E2EConfig) {
cfg.UseSSZOnly = true
}
}
func WithBuilder() E2EConfigOpt {
return func(cfg *E2EConfig) {
cfg.UseBuilder = true
}
}
func WithSSZOnly() E2EConfigOpt {
return func(cfg *E2EConfig) {
if err := os.Setenv(params.EnvNameOverrideAccept, api.OctetStreamMediaType); err != nil {
logrus.Fatal(err)
}
}
}
// E2EConfig defines the struct for all configurations needed for E2E testing.
type E2EConfig struct {
TestCheckpointSync bool
@@ -76,7 +80,6 @@ type E2EConfig struct {
UseFixedPeerIDs bool
UseValidatorCrossClient bool
UseBeaconRestApi bool
UseSSZOnly bool
UseBuilder bool
EpochsToRun uint64
Seed int64

View File

@@ -46,7 +46,6 @@ go_library(
"//api/server/structs:go_default_library",
"//beacon-chain/core/helpers:go_default_library",
"//beacon-chain/core/signing:go_default_library",
"//config/features:go_default_library",
"//config/fieldparams:go_default_library",
"//config/params:go_default_library",
"//consensus-types/primitives:go_default_library",
@@ -128,7 +127,6 @@ go_test(
"//api/server/structs:go_default_library",
"//beacon-chain/core/helpers:go_default_library",
"//beacon-chain/rpc/eth/shared/testing:go_default_library",
"//config/features:go_default_library",
"//config/params:go_default_library",
"//consensus-types/primitives:go_default_library",
"//consensus-types/validator:go_default_library",

View File

@@ -9,7 +9,6 @@ import (
"github.com/OffchainLabs/prysm/v6/api"
"github.com/OffchainLabs/prysm/v6/api/server/structs"
"github.com/OffchainLabs/prysm/v6/config/features"
"github.com/OffchainLabs/prysm/v6/consensus-types/primitives"
ethpb "github.com/OffchainLabs/prysm/v6/proto/prysm/v1alpha1"
"github.com/OffchainLabs/prysm/v6/testing/assert"
@@ -215,11 +214,6 @@ func TestGetBeaconBlock_SSZ_BellatrixValid(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
proto := testhelpers.GenerateProtoBellatrixBeaconBlock()
bytes, err := proto.MarshalSSZ()
require.NoError(t, err)
@@ -262,11 +256,6 @@ func TestGetBeaconBlock_SSZ_BlindedBellatrixValid(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
proto := testhelpers.GenerateProtoBlindedBellatrixBeaconBlock()
bytes, err := proto.MarshalSSZ()
require.NoError(t, err)
@@ -309,11 +298,6 @@ func TestGetBeaconBlock_SSZ_CapellaValid(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
proto := testhelpers.GenerateProtoCapellaBeaconBlock()
bytes, err := proto.MarshalSSZ()
require.NoError(t, err)
@@ -356,11 +340,6 @@ func TestGetBeaconBlock_SSZ_BlindedCapellaValid(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
proto := testhelpers.GenerateProtoBlindedCapellaBeaconBlock()
bytes, err := proto.MarshalSSZ()
require.NoError(t, err)
@@ -403,11 +382,6 @@ func TestGetBeaconBlock_SSZ_DenebValid(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
proto := testhelpers.GenerateProtoDenebBeaconBlockContents()
bytes, err := proto.MarshalSSZ()
require.NoError(t, err)
@@ -450,11 +424,6 @@ func TestGetBeaconBlock_SSZ_BlindedDenebValid(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
proto := testhelpers.GenerateProtoBlindedDenebBeaconBlock()
bytes, err := proto.MarshalSSZ()
require.NoError(t, err)
@@ -497,11 +466,6 @@ func TestGetBeaconBlock_SSZ_ElectraValid(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
proto := testhelpers.GenerateProtoElectraBeaconBlockContents()
bytes, err := proto.MarshalSSZ()
require.NoError(t, err)
@@ -544,11 +508,6 @@ func TestGetBeaconBlock_SSZ_BlindedElectraValid(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
proto := testhelpers.GenerateProtoBlindedElectraBeaconBlock()
bytes, err := proto.MarshalSSZ()
require.NoError(t, err)
@@ -591,11 +550,6 @@ func TestGetBeaconBlock_SSZ_UnsupportedVersion(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
const slot = primitives.Slot(1)
randaoReveal := []byte{2}
graffiti := []byte{3}
@@ -625,11 +579,6 @@ func TestGetBeaconBlock_SSZ_InvalidBlindedHeader(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
proto := testhelpers.GenerateProtoBellatrixBeaconBlock()
bytes, err := proto.MarshalSSZ()
require.NoError(t, err)
@@ -663,11 +612,6 @@ func TestGetBeaconBlock_SSZ_InvalidVersionHeader(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
proto := testhelpers.GenerateProtoBellatrixBeaconBlock()
bytes, err := proto.MarshalSSZ()
require.NoError(t, err)
@@ -701,11 +645,6 @@ func TestGetBeaconBlock_SSZ_GetSSZError(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
const slot = primitives.Slot(1)
randaoReveal := []byte{2}
graffiti := []byte{3}
@@ -731,11 +670,6 @@ func TestGetBeaconBlock_SSZ_Phase0Valid(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
proto := testhelpers.GenerateProtoPhase0BeaconBlock()
bytes, err := proto.MarshalSSZ()
require.NoError(t, err)
@@ -778,11 +712,6 @@ func TestGetBeaconBlock_SSZ_AltairValid(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
resetFn := features.InitWithReset(&features.Flags{
SSZOnly: true,
})
defer resetFn()
proto := testhelpers.GenerateProtoAltairBeaconBlock()
bytes, err := proto.MarshalSSZ()
require.NoError(t, err)

View File

@@ -7,15 +7,19 @@ import (
"fmt"
"io"
"net/http"
"os"
"strings"
"github.com/OffchainLabs/prysm/v6/api"
"github.com/OffchainLabs/prysm/v6/config/features"
"github.com/OffchainLabs/prysm/v6/api/apiutil"
"github.com/OffchainLabs/prysm/v6/config/params"
"github.com/OffchainLabs/prysm/v6/network/httputil"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
type reqOption func(*http.Request)
type RestHandler interface {
Get(ctx context.Context, endpoint string, resp interface{}) error
GetSSZ(ctx context.Context, endpoint string) ([]byte, http.Header, error)
@@ -26,16 +30,30 @@ type RestHandler interface {
}
type BeaconApiRestHandler struct {
client http.Client
host string
client http.Client
host string
reqOverrides []reqOption
}
// NewBeaconApiRestHandler returns a RestHandler
func NewBeaconApiRestHandler(client http.Client, host string) RestHandler {
return &BeaconApiRestHandler{
brh := &BeaconApiRestHandler{
client: client,
host: host,
}
brh.appendAcceptOverride()
return brh
}
// appendAcceptOverride enables the Accept header to be customized at runtime via an environment variable.
// This is specified as an env var because it is a niche option that prysm may use for performance testing or debugging
// bug which users are unlikely to need. Using an env var keeps the set of user-facing flags cleaner.
func (c *BeaconApiRestHandler) appendAcceptOverride() {
if accept := os.Getenv(params.EnvNameOverrideAccept); accept != "" {
c.reqOverrides = append(c.reqOverrides, func(req *http.Request) {
req.Header.Set("Accept", accept)
})
}
}
// HttpClient returns the underlying HTTP client of the handler
@@ -56,7 +74,6 @@ func (c *BeaconApiRestHandler) Get(ctx context.Context, endpoint string, resp in
if err != nil {
return errors.Wrapf(err, "failed to create request for endpoint %s", url)
}
httpResp, err := c.client.Do(req)
if err != nil {
return errors.Wrapf(err, "failed to perform request for endpoint %s", url)
@@ -76,13 +93,16 @@ func (c *BeaconApiRestHandler) GetSSZ(ctx context.Context, endpoint string) ([]b
if err != nil {
return nil, nil, errors.Wrapf(err, "failed to create request for endpoint %s", url)
}
primaryAcceptType := fmt.Sprintf("%s;q=%s", api.OctetStreamMediaType, "0.95")
secondaryAcceptType := fmt.Sprintf("%s;q=%s", api.JsonMediaType, "0.9")
acceptHeaderString := fmt.Sprintf("%s,%s", primaryAcceptType, secondaryAcceptType)
if features.Get().SSZOnly {
acceptHeaderString = api.OctetStreamMediaType
}
req.Header.Set("Accept", acceptHeaderString)
for _, o := range c.reqOverrides {
o(req)
}
httpResp, err := c.client.Do(req)
if err != nil {
return nil, nil, errors.Wrapf(err, "failed to perform request for endpoint %s", url)
@@ -92,16 +112,17 @@ func (c *BeaconApiRestHandler) GetSSZ(ctx context.Context, endpoint string) ([]b
return
}
}()
accept := req.Header.Get("Accept")
contentType := httpResp.Header.Get("Content-Type")
body, err := io.ReadAll(httpResp.Body)
if err != nil {
return nil, nil, errors.Wrapf(err, "failed to read response body for %s", httpResp.Request.URL)
}
if !strings.Contains(primaryAcceptType, contentType) {
if !apiutil.PrimaryAcceptMatches(accept, contentType) {
log.WithFields(logrus.Fields{
"primaryAcceptType": primaryAcceptType,
"secondaryAcceptType": secondaryAcceptType,
"receivedAcceptType": contentType,
"Accept": accept,
"Content-Type": contentType,
}).Debug("Server responded with non primary accept type")
}
@@ -115,10 +136,6 @@ func (c *BeaconApiRestHandler) GetSSZ(ctx context.Context, endpoint string) ([]b
return nil, nil, errorJson
}
if features.Get().SSZOnly && contentType != api.OctetStreamMediaType {
return nil, nil, errors.Errorf("server responded with non primary accept type %s", contentType)
}
return body, httpResp.Header, nil
}

View File

@@ -7,11 +7,13 @@ import (
"io"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/OffchainLabs/prysm/v6/api"
"github.com/OffchainLabs/prysm/v6/api/server/structs"
"github.com/OffchainLabs/prysm/v6/config/params"
"github.com/OffchainLabs/prysm/v6/network/httputil"
"github.com/OffchainLabs/prysm/v6/testing/assert"
"github.com/OffchainLabs/prysm/v6/testing/require"
@@ -143,6 +145,25 @@ func TestGetSSZ(t *testing.T) {
})
}
func TestAcceptOverrideSSZ(t *testing.T) {
name := "TestAcceptOverride"
orig := os.Getenv(params.EnvNameOverrideAccept)
defer func() {
require.NoError(t, os.Setenv(params.EnvNameOverrideAccept, orig))
}()
require.NoError(t, os.Setenv(params.EnvNameOverrideAccept, name))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, name, r.Header.Get("Accept"))
w.WriteHeader(200)
_, err := w.Write([]byte("ok"))
require.NoError(t, err)
}))
defer srv.Close()
c := NewBeaconApiRestHandler(http.Client{Timeout: time.Second * 5}, srv.URL)
_, _, err := c.GetSSZ(t.Context(), "/test")
require.NoError(t, err)
}
func TestPost(t *testing.T) {
ctx := t.Context()
const endpoint = "/example/rest/api/endpoint"