From 77958022e7a6af74196855f053498d4a7172ea78 Mon Sep 17 00:00:00 2001 From: james-prysm <90280386+james-prysm@users.noreply.github.com> Date: Tue, 22 Jul 2025 11:06:51 -0500 Subject: [PATCH] removing ssz-only flag ( reverting feature) and fix accept header middleware (#15433) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * removing ssz-only flag * gaz * reverting other uses of sszonly * gaz * adding kasey and radek's suggestions * update changelog * adding test * radek advice with new headers and tests * adding logs and fixing comments * adding logs and fixing comments * gaz * Update validator/client/beacon-api/rest_handler_client.go Co-authored-by: Radosław Kapka * Update api/apiutil/header.go Co-authored-by: Radosław Kapka * Update api/apiutil/header.go Co-authored-by: Radosław Kapka * radek's comments * adding another failing case based on radek's suggestion * another unit test --------- Co-authored-by: Radosław Kapka --- api/apiutil/BUILD.bazel | 16 +- api/apiutil/header.go | 122 ++++++++++++ api/apiutil/header_test.go | 174 ++++++++++++++++++ api/server/middleware/BUILD.bazel | 1 + api/server/middleware/middleware.go | 37 +--- changelog/james-prysm_remove-ssz-only-flag.md | 7 + config/features/config.go | 5 - config/features/flags.go | 7 - config/params/testutils.go | 4 + testing/endtoend/components/validator.go | 3 - testing/endtoend/minimal_scenario_e2e_test.go | 8 +- testing/endtoend/types/BUILD.bazel | 2 + testing/endtoend/types/types.go | 17 +- validator/client/beacon-api/BUILD.bazel | 2 - .../beacon-api/get_beacon_block_test.go | 71 ------- .../client/beacon-api/rest_handler_client.go | 49 +++-- .../beacon-api/rest_handler_client_test.go | 21 +++ 17 files changed, 394 insertions(+), 152 deletions(-) create mode 100644 api/apiutil/header.go create mode 100644 api/apiutil/header_test.go create mode 100644 changelog/james-prysm_remove-ssz-only-flag.md diff --git a/api/apiutil/BUILD.bazel b/api/apiutil/BUILD.bazel index df32449e36..b9e4fc4782 100644 --- a/api/apiutil/BUILD.bazel +++ b/api/apiutil/BUILD.bazel @@ -2,18 +2,28 @@ load("@prysm//tools/go:def.bzl", "go_library", "go_test") go_library( name = "go_default_library", - srcs = ["common.go"], + srcs = [ + "common.go", + "header.go", + ], importpath = "github.com/OffchainLabs/prysm/v6/api/apiutil", visibility = ["//visibility:public"], - deps = ["//consensus-types/primitives:go_default_library"], + deps = [ + "//consensus-types/primitives:go_default_library", + "@com_github_sirupsen_logrus//:go_default_library", + ], ) go_test( name = "go_default_test", - srcs = ["common_test.go"], + srcs = [ + "common_test.go", + "header_test.go", + ], embed = [":go_default_library"], deps = [ "//consensus-types/primitives:go_default_library", "//testing/assert:go_default_library", + "//testing/require:go_default_library", ], ) diff --git a/api/apiutil/header.go b/api/apiutil/header.go new file mode 100644 index 0000000000..4ef6ca3cb2 --- /dev/null +++ b/api/apiutil/header.go @@ -0,0 +1,122 @@ +package apiutil + +import ( + "mime" + "sort" + "strconv" + "strings" + + log "github.com/sirupsen/logrus" +) + +type mediaRange struct { + mt string // canonicalised media‑type, e.g. "application/json" + q float64 // quality factor (0‑1) + raw string // original string – useful for logging/debugging + spec int // 2=exact, 1=type/*, 0=*/* +} + +func parseMediaRange(field string) (mediaRange, bool) { + field = strings.TrimSpace(field) + + mt, params, err := mime.ParseMediaType(field) + if err != nil { + log.WithError(err).Debug("Failed to parse header field") + return mediaRange{}, false + } + + r := mediaRange{mt: mt, q: 1, spec: 2, raw: field} + + if qs, ok := params["q"]; ok { + v, err := strconv.ParseFloat(qs, 64) + if err != nil || v < 0 || v > 1 { + log.WithField("q", qs).Debug("Invalid quality factor (0‑1)") + return mediaRange{}, false // skip invalid entry + } + r.q = v + } + + switch { + case mt == "*/*": + r.spec = 0 + case strings.HasSuffix(mt, "/*"): + r.spec = 1 + } + return r, true +} + +func hasExplicitQ(r mediaRange) bool { + return strings.Contains(strings.ToLower(r.raw), ";q=") +} + +// ParseAccept returns media ranges sorted by q (desc) then specificity. +func ParseAccept(header string) []mediaRange { + if header == "" { + return []mediaRange{{mt: "*/*", q: 1, spec: 0, raw: "*/*"}} + } + + var out []mediaRange + for _, field := range strings.Split(header, ",") { + if r, ok := parseMediaRange(field); ok { + out = append(out, r) + } + } + + sort.SliceStable(out, func(i, j int) bool { + ei, ej := hasExplicitQ(out[i]), hasExplicitQ(out[j]) + if ei != ej { + return ei // explicit beats implicit + } + if out[i].q != out[j].q { + return out[i].q > out[j].q + } + return out[i].spec > out[j].spec + }) + return out +} + +// Matches reports whether content type is acceptable per the header. +func Matches(header, ct string) bool { + for _, r := range ParseAccept(header) { + switch { + case r.q == 0: + continue + case r.mt == "*/*": + return true + case strings.HasSuffix(r.mt, "/*"): + if strings.HasPrefix(ct, r.mt[:len(r.mt)-1]) { + return true + } + case r.mt == ct: + return true + } + } + return false +} + +// Negotiate selects the best server type according to the header. +// Returns the chosen type and true, or "", false when nothing matches. +func Negotiate(header string, serverTypes []string) (string, bool) { + for _, r := range ParseAccept(header) { + if r.q == 0 { + continue + } + for _, s := range serverTypes { + if Matches(r.mt, s) { + return s, true + } + } + } + return "", false +} + +// PrimaryAcceptMatches only checks if the first accept matches +func PrimaryAcceptMatches(header, produced string) bool { + for _, r := range ParseAccept(header) { + if r.q == 0 { + continue // explicitly unacceptable – skip + } + return Matches(r.mt, produced) + } + return false +} diff --git a/api/apiutil/header_test.go b/api/apiutil/header_test.go new file mode 100644 index 0000000000..5c074f0b60 --- /dev/null +++ b/api/apiutil/header_test.go @@ -0,0 +1,174 @@ +package apiutil + +import ( + "testing" + + "github.com/OffchainLabs/prysm/v6/testing/require" +) + +func TestParseAccept(t *testing.T) { + type want struct { + mt string + q float64 + spec int + } + + cases := []struct { + name string + header string + want []want + }{ + { + name: "empty header becomes */*;q=1", + header: "", + want: []want{{mt: "*/*", q: 1, spec: 0}}, + }, + { + name: "quality ordering then specificity", + header: "application/json;q=0.2, */*;q=0.1, application/xml;q=0.5, text/*;q=0.5", + want: []want{ + {mt: "application/xml", q: 0.5, spec: 2}, + {mt: "text/*", q: 0.5, spec: 1}, + {mt: "application/json", q: 0.2, spec: 2}, + {mt: "*/*", q: 0.1, spec: 0}, + }, + }, + { + name: "invalid pieces are skipped", + header: "text/plain; q=boom, application/json", + want: []want{{mt: "application/json", q: 1, spec: 2}}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := ParseAccept(tc.header) + gotProjected := make([]want, len(got)) + for i, g := range got { + gotProjected[i] = want{mt: g.mt, q: g.q, spec: g.spec} + } + require.DeepEqual(t, gotProjected, tc.want) + }) + } +} + +func TestMatches(t *testing.T) { + cases := []struct { + name string + accept string + ct string + matches bool + }{ + {"exact match", "application/json", "application/json", true}, + {"type wildcard", "application/*;q=0.8", "application/xml", true}, + {"global wildcard", "*/*;q=0.1", "image/png", true}, + {"explicitly unacceptable (q=0)", "text/*;q=0", "text/plain", false}, + {"no match", "image/png", "application/json", false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := Matches(tc.accept, tc.ct) + require.Equal(t, tc.matches, got) + }) + } +} + +func TestNegotiate(t *testing.T) { + cases := []struct { + name string + accept string + serverTypes []string + wantType string + ok bool + }{ + { + name: "highest quality wins", + accept: "application/json;q=0.8,application/xml;q=0.9", + serverTypes: []string{"application/json", "application/xml"}, + wantType: "application/xml", + ok: true, + }, + { + name: "wildcard matches first server type", + accept: "*/*;q=0.5", + serverTypes: []string{"application/octet-stream", "application/json"}, + wantType: "application/octet-stream", + ok: true, + }, + { + name: "no acceptable type", + accept: "image/png", + serverTypes: []string{"application/json"}, + wantType: "", + ok: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, ok := Negotiate(tc.accept, tc.serverTypes) + require.Equal(t, tc.ok, ok) + require.Equal(t, tc.wantType, got) + }) + } +} + +func TestPrimaryAcceptMatches(t *testing.T) { + tests := []struct { + name string + accept string + produced string + expect bool + }{ + { + name: "prefers json", + accept: "application/json;q=0.9,application/xml", + produced: "application/json", + expect: true, + }, + { + name: "wildcard application beats other wildcard", + accept: "application/*;q=0.2,*/*;q=0.1", + produced: "application/xml", + expect: true, + }, + { + name: "json wins", + accept: "application/xml;q=0.8,application/json;q=0.9", + produced: "application/json", + expect: true, + }, + { + name: "json loses", + accept: "application/xml;q=0.8,application/json;q=0.9,application/octet-stream;q=0.99", + produced: "application/json", + expect: false, + }, + { + name: "json wins with non q option", + accept: "application/xml;q=0.8,image/png,application/json;q=0.9", + produced: "application/json", + expect: true, + }, + { + name: "json not primary", + accept: "image/png,application/json", + produced: "application/json", + expect: false, + }, + { + name: "absent header", + accept: "", + produced: "text/plain", + expect: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := PrimaryAcceptMatches(tc.accept, tc.produced) + require.Equal(t, got, tc.expect) + }) + } +} diff --git a/api/server/middleware/BUILD.bazel b/api/server/middleware/BUILD.bazel index f355685ad7..eaca515833 100644 --- a/api/server/middleware/BUILD.bazel +++ b/api/server/middleware/BUILD.bazel @@ -10,6 +10,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//api:go_default_library", + "//api/apiutil:go_default_library", "@com_github_rs_cors//:go_default_library", "@com_github_sirupsen_logrus//:go_default_library", ], diff --git a/api/server/middleware/middleware.go b/api/server/middleware/middleware.go index ebc9ebc3fe..9ab7c410c2 100644 --- a/api/server/middleware/middleware.go +++ b/api/server/middleware/middleware.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/OffchainLabs/prysm/v6/api" + "github.com/OffchainLabs/prysm/v6/api/apiutil" "github.com/rs/cors" log "github.com/sirupsen/logrus" ) @@ -74,42 +75,10 @@ func ContentTypeHandler(acceptedMediaTypes []string) Middleware { func AcceptHeaderHandler(serverAcceptedTypes []string) Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - acceptHeader := r.Header.Get("Accept") - // header is optional and should skip if not provided - if acceptHeader == "" { - next.ServeHTTP(w, r) + if _, ok := apiutil.Negotiate(r.Header.Get("Accept"), serverAcceptedTypes); !ok { + http.Error(w, "Not Acceptable", http.StatusNotAcceptable) return } - - accepted := false - acceptTypes := strings.Split(acceptHeader, ",") - // follows rules defined in https://datatracker.ietf.org/doc/html/rfc2616#section-14.1 - for _, acceptType := range acceptTypes { - acceptType = strings.TrimSpace(acceptType) - if acceptType == "*/*" { - accepted = true - break - } - for _, serverAcceptedType := range serverAcceptedTypes { - if strings.HasPrefix(acceptType, serverAcceptedType) { - accepted = true - break - } - if acceptType != "/*" && strings.HasSuffix(acceptType, "/*") && strings.HasPrefix(serverAcceptedType, acceptType[:len(acceptType)-2]) { - accepted = true - break - } - } - if accepted { - break - } - } - - if !accepted { - http.Error(w, fmt.Sprintf("Not Acceptable: %s", acceptHeader), http.StatusNotAcceptable) - return - } - next.ServeHTTP(w, r) }) } diff --git a/changelog/james-prysm_remove-ssz-only-flag.md b/changelog/james-prysm_remove-ssz-only-flag.md new file mode 100644 index 0000000000..7325cc60fa --- /dev/null +++ b/changelog/james-prysm_remove-ssz-only-flag.md @@ -0,0 +1,7 @@ +### Removed + +- Partially reverting pr #15390 removing the `ssz-only` debug flag until there is a real usecase for the flag + +### Added + +- Added new PRYSM_API_OVERRIDE_ACCEPT environment variable to override ssz accept header as a replacement to flag \ No newline at end of file diff --git a/config/features/config.go b/config/features/config.go index 19a28684cc..a9b5a08e42 100644 --- a/config/features/config.go +++ b/config/features/config.go @@ -51,7 +51,6 @@ type Flags struct { EnableExperimentalAttestationPool bool // EnableExperimentalAttestationPool enables an experimental attestation pool design. DisableDutiesV2 bool // DisableDutiesV2 sets validator client to use the get Duties endpoint EnableWeb bool // EnableWeb enables the webui on the validator client - SSZOnly bool // SSZOnly forces the validator client to use SSZ for communication with the beacon node when REST mode is enabled (useful for debugging) // Logging related toggles. DisableGRPCConnectionLogs bool // Disables logging when a new grpc client has connected. EnableFullSSZDataLogging bool // Enables logging for full ssz data on rejected gossip messages @@ -340,10 +339,6 @@ func ConfigureValidator(ctx *cli.Context) error { logEnabled(EnableWebFlag) cfg.EnableWeb = true } - if ctx.Bool(SSZOnly.Name) { - logEnabled(SSZOnly) - cfg.SSZOnly = true - } cfg.KeystoreImportDebounceInterval = ctx.Duration(dynamicKeyReloadDebounceInterval.Name) Init(cfg) diff --git a/config/features/flags.go b/config/features/flags.go index ef0616d844..880092336b 100644 --- a/config/features/flags.go +++ b/config/features/flags.go @@ -197,12 +197,6 @@ var ( Usage: "(Work in progress): Enables the web portal for the validator client.", Value: false, } - - // SSZOnly forces the validator client to use SSZ for communication with the beacon node when REST mode is enabled - SSZOnly = &cli.BoolFlag{ - Name: "ssz-only", - Usage: "(debug): Forces the validator client to use SSZ for communication with the beacon node when REST mode is enabled", - } ) // devModeFlags holds list of flags that are set when development mode is on. @@ -225,7 +219,6 @@ var ValidatorFlags = append(deprecatedFlags, []cli.Flag{ EnableBeaconRESTApi, DisableDutiesV2, EnableWebFlag, - SSZOnly, }...) // E2EValidatorFlags contains a list of the validator feature flags to be tested in E2E. diff --git a/config/params/testutils.go b/config/params/testutils.go index 7c79e0aa65..320148f944 100644 --- a/config/params/testutils.go +++ b/config/params/testutils.go @@ -4,6 +4,10 @@ import ( "testing" ) +const ( + EnvNameOverrideAccept = "PRYSM_API_OVERRIDE_ACCEPT" +) + // SetupTestConfigCleanup preserves configurations allowing to modify them within tests without any // restrictions, everything is restored after the test. func SetupTestConfigCleanup(t testing.TB) { diff --git a/testing/endtoend/components/validator.go b/testing/endtoend/components/validator.go index f10e6aaa54..f2b7bf2374 100644 --- a/testing/endtoend/components/validator.go +++ b/testing/endtoend/components/validator.go @@ -248,9 +248,6 @@ func (v *ValidatorNode) Start(ctx context.Context) error { args = append(args, fmt.Sprintf("--%s=http://localhost:%d", flags.BeaconRESTApiProviderFlag.Name, beaconRestApiPort), fmt.Sprintf("--%s", features.EnableBeaconRESTApi.Name)) - if v.config.UseSSZOnly { - args = append(args, fmt.Sprintf("--%s", features.SSZOnly.Name)) - } } // Only apply e2e flags to the current branch. New flags may not exist in previous release. diff --git a/testing/endtoend/minimal_scenario_e2e_test.go b/testing/endtoend/minimal_scenario_e2e_test.go index 96ba723d1b..1d4151dd6e 100644 --- a/testing/endtoend/minimal_scenario_e2e_test.go +++ b/testing/endtoend/minimal_scenario_e2e_test.go @@ -25,14 +25,14 @@ func TestEndToEnd_MinimalConfig_Web3Signer_PersistentKeys(t *testing.T) { e2eMinimal(t, types.InitForkCfg(version.Bellatrix, version.Electra, params.E2ETestConfig()), types.WithRemoteSignerAndPersistentKeysFile()).run() } -func TestEndToEnd_MinimalConfig_ValidatorRESTApi(t *testing.T) { - e2eMinimal(t, types.InitForkCfg(version.Bellatrix, version.Electra, params.E2ETestConfig()), types.WithCheckpointSync(), types.WithValidatorRESTApi()).run() -} - func TestEndToEnd_MinimalConfig_ValidatorRESTApi_SSZ(t *testing.T) { e2eMinimal(t, types.InitForkCfg(version.Bellatrix, version.Electra, params.E2ETestConfig()), types.WithCheckpointSync(), types.WithValidatorRESTApi(), types.WithSSZOnly()).run() } +func TestEndToEnd_MinimalConfig_ValidatorRESTApi(t *testing.T) { + e2eMinimal(t, types.InitForkCfg(version.Bellatrix, version.Electra, params.E2ETestConfig()), types.WithCheckpointSync(), types.WithValidatorRESTApi()).run() +} + func TestEndToEnd_ScenarioRun_EEOffline(t *testing.T) { t.Skip("TODO(#10242) Prysm is current unable to handle an offline e2e") cfg := types.InitForkCfg(version.Bellatrix, version.Deneb, params.E2ETestConfig()) diff --git a/testing/endtoend/types/BUILD.bazel b/testing/endtoend/types/BUILD.bazel index 7efe33f10c..662a63091f 100644 --- a/testing/endtoend/types/BUILD.bazel +++ b/testing/endtoend/types/BUILD.bazel @@ -11,9 +11,11 @@ go_library( importpath = "github.com/OffchainLabs/prysm/v6/testing/endtoend/types", visibility = ["//testing/endtoend:__subpackages__"], deps = [ + "//api:go_default_library", "//config/params:go_default_library", "//consensus-types/primitives:go_default_library", "//runtime/version:go_default_library", + "@com_github_sirupsen_logrus//:go_default_library", "@org_golang_google_grpc//:go_default_library", ], ) diff --git a/testing/endtoend/types/types.go b/testing/endtoend/types/types.go index 0b4040247f..ef263663f0 100644 --- a/testing/endtoend/types/types.go +++ b/testing/endtoend/types/types.go @@ -6,9 +6,11 @@ import ( "context" "os" + "github.com/OffchainLabs/prysm/v6/api" "github.com/OffchainLabs/prysm/v6/config/params" "github.com/OffchainLabs/prysm/v6/consensus-types/primitives" "github.com/OffchainLabs/prysm/v6/runtime/version" + "github.com/sirupsen/logrus" "google.golang.org/grpc" ) @@ -51,18 +53,20 @@ func WithValidatorRESTApi() E2EConfigOpt { } } -func WithSSZOnly() E2EConfigOpt { - return func(cfg *E2EConfig) { - cfg.UseSSZOnly = true - } -} - func WithBuilder() E2EConfigOpt { return func(cfg *E2EConfig) { cfg.UseBuilder = true } } +func WithSSZOnly() E2EConfigOpt { + return func(cfg *E2EConfig) { + if err := os.Setenv(params.EnvNameOverrideAccept, api.OctetStreamMediaType); err != nil { + logrus.Fatal(err) + } + } +} + // E2EConfig defines the struct for all configurations needed for E2E testing. type E2EConfig struct { TestCheckpointSync bool @@ -76,7 +80,6 @@ type E2EConfig struct { UseFixedPeerIDs bool UseValidatorCrossClient bool UseBeaconRestApi bool - UseSSZOnly bool UseBuilder bool EpochsToRun uint64 Seed int64 diff --git a/validator/client/beacon-api/BUILD.bazel b/validator/client/beacon-api/BUILD.bazel index 23d675cc88..2917f8cd5d 100644 --- a/validator/client/beacon-api/BUILD.bazel +++ b/validator/client/beacon-api/BUILD.bazel @@ -46,7 +46,6 @@ go_library( "//api/server/structs:go_default_library", "//beacon-chain/core/helpers:go_default_library", "//beacon-chain/core/signing:go_default_library", - "//config/features:go_default_library", "//config/fieldparams:go_default_library", "//config/params:go_default_library", "//consensus-types/primitives:go_default_library", @@ -128,7 +127,6 @@ go_test( "//api/server/structs:go_default_library", "//beacon-chain/core/helpers:go_default_library", "//beacon-chain/rpc/eth/shared/testing:go_default_library", - "//config/features:go_default_library", "//config/params:go_default_library", "//consensus-types/primitives:go_default_library", "//consensus-types/validator:go_default_library", diff --git a/validator/client/beacon-api/get_beacon_block_test.go b/validator/client/beacon-api/get_beacon_block_test.go index ed691e33e7..9ce5fc2401 100644 --- a/validator/client/beacon-api/get_beacon_block_test.go +++ b/validator/client/beacon-api/get_beacon_block_test.go @@ -9,7 +9,6 @@ import ( "github.com/OffchainLabs/prysm/v6/api" "github.com/OffchainLabs/prysm/v6/api/server/structs" - "github.com/OffchainLabs/prysm/v6/config/features" "github.com/OffchainLabs/prysm/v6/consensus-types/primitives" ethpb "github.com/OffchainLabs/prysm/v6/proto/prysm/v1alpha1" "github.com/OffchainLabs/prysm/v6/testing/assert" @@ -215,11 +214,6 @@ func TestGetBeaconBlock_SSZ_BellatrixValid(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - proto := testhelpers.GenerateProtoBellatrixBeaconBlock() bytes, err := proto.MarshalSSZ() require.NoError(t, err) @@ -262,11 +256,6 @@ func TestGetBeaconBlock_SSZ_BlindedBellatrixValid(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - proto := testhelpers.GenerateProtoBlindedBellatrixBeaconBlock() bytes, err := proto.MarshalSSZ() require.NoError(t, err) @@ -309,11 +298,6 @@ func TestGetBeaconBlock_SSZ_CapellaValid(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - proto := testhelpers.GenerateProtoCapellaBeaconBlock() bytes, err := proto.MarshalSSZ() require.NoError(t, err) @@ -356,11 +340,6 @@ func TestGetBeaconBlock_SSZ_BlindedCapellaValid(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - proto := testhelpers.GenerateProtoBlindedCapellaBeaconBlock() bytes, err := proto.MarshalSSZ() require.NoError(t, err) @@ -403,11 +382,6 @@ func TestGetBeaconBlock_SSZ_DenebValid(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - proto := testhelpers.GenerateProtoDenebBeaconBlockContents() bytes, err := proto.MarshalSSZ() require.NoError(t, err) @@ -450,11 +424,6 @@ func TestGetBeaconBlock_SSZ_BlindedDenebValid(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - proto := testhelpers.GenerateProtoBlindedDenebBeaconBlock() bytes, err := proto.MarshalSSZ() require.NoError(t, err) @@ -497,11 +466,6 @@ func TestGetBeaconBlock_SSZ_ElectraValid(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - proto := testhelpers.GenerateProtoElectraBeaconBlockContents() bytes, err := proto.MarshalSSZ() require.NoError(t, err) @@ -544,11 +508,6 @@ func TestGetBeaconBlock_SSZ_BlindedElectraValid(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - proto := testhelpers.GenerateProtoBlindedElectraBeaconBlock() bytes, err := proto.MarshalSSZ() require.NoError(t, err) @@ -591,11 +550,6 @@ func TestGetBeaconBlock_SSZ_UnsupportedVersion(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - const slot = primitives.Slot(1) randaoReveal := []byte{2} graffiti := []byte{3} @@ -625,11 +579,6 @@ func TestGetBeaconBlock_SSZ_InvalidBlindedHeader(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - proto := testhelpers.GenerateProtoBellatrixBeaconBlock() bytes, err := proto.MarshalSSZ() require.NoError(t, err) @@ -663,11 +612,6 @@ func TestGetBeaconBlock_SSZ_InvalidVersionHeader(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - proto := testhelpers.GenerateProtoBellatrixBeaconBlock() bytes, err := proto.MarshalSSZ() require.NoError(t, err) @@ -701,11 +645,6 @@ func TestGetBeaconBlock_SSZ_GetSSZError(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - const slot = primitives.Slot(1) randaoReveal := []byte{2} graffiti := []byte{3} @@ -731,11 +670,6 @@ func TestGetBeaconBlock_SSZ_Phase0Valid(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - proto := testhelpers.GenerateProtoPhase0BeaconBlock() bytes, err := proto.MarshalSSZ() require.NoError(t, err) @@ -778,11 +712,6 @@ func TestGetBeaconBlock_SSZ_AltairValid(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - resetFn := features.InitWithReset(&features.Flags{ - SSZOnly: true, - }) - defer resetFn() - proto := testhelpers.GenerateProtoAltairBeaconBlock() bytes, err := proto.MarshalSSZ() require.NoError(t, err) diff --git a/validator/client/beacon-api/rest_handler_client.go b/validator/client/beacon-api/rest_handler_client.go index e2dae19ab3..beb0835111 100644 --- a/validator/client/beacon-api/rest_handler_client.go +++ b/validator/client/beacon-api/rest_handler_client.go @@ -7,15 +7,19 @@ import ( "fmt" "io" "net/http" + "os" "strings" "github.com/OffchainLabs/prysm/v6/api" - "github.com/OffchainLabs/prysm/v6/config/features" + "github.com/OffchainLabs/prysm/v6/api/apiutil" + "github.com/OffchainLabs/prysm/v6/config/params" "github.com/OffchainLabs/prysm/v6/network/httputil" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) +type reqOption func(*http.Request) + type RestHandler interface { Get(ctx context.Context, endpoint string, resp interface{}) error GetSSZ(ctx context.Context, endpoint string) ([]byte, http.Header, error) @@ -26,16 +30,30 @@ type RestHandler interface { } type BeaconApiRestHandler struct { - client http.Client - host string + client http.Client + host string + reqOverrides []reqOption } // NewBeaconApiRestHandler returns a RestHandler func NewBeaconApiRestHandler(client http.Client, host string) RestHandler { - return &BeaconApiRestHandler{ + brh := &BeaconApiRestHandler{ client: client, host: host, } + brh.appendAcceptOverride() + return brh +} + +// appendAcceptOverride enables the Accept header to be customized at runtime via an environment variable. +// This is specified as an env var because it is a niche option that prysm may use for performance testing or debugging +// bug which users are unlikely to need. Using an env var keeps the set of user-facing flags cleaner. +func (c *BeaconApiRestHandler) appendAcceptOverride() { + if accept := os.Getenv(params.EnvNameOverrideAccept); accept != "" { + c.reqOverrides = append(c.reqOverrides, func(req *http.Request) { + req.Header.Set("Accept", accept) + }) + } } // HttpClient returns the underlying HTTP client of the handler @@ -56,7 +74,6 @@ func (c *BeaconApiRestHandler) Get(ctx context.Context, endpoint string, resp in if err != nil { return errors.Wrapf(err, "failed to create request for endpoint %s", url) } - httpResp, err := c.client.Do(req) if err != nil { return errors.Wrapf(err, "failed to perform request for endpoint %s", url) @@ -76,13 +93,16 @@ func (c *BeaconApiRestHandler) GetSSZ(ctx context.Context, endpoint string) ([]b if err != nil { return nil, nil, errors.Wrapf(err, "failed to create request for endpoint %s", url) } + primaryAcceptType := fmt.Sprintf("%s;q=%s", api.OctetStreamMediaType, "0.95") secondaryAcceptType := fmt.Sprintf("%s;q=%s", api.JsonMediaType, "0.9") acceptHeaderString := fmt.Sprintf("%s,%s", primaryAcceptType, secondaryAcceptType) - if features.Get().SSZOnly { - acceptHeaderString = api.OctetStreamMediaType - } req.Header.Set("Accept", acceptHeaderString) + + for _, o := range c.reqOverrides { + o(req) + } + httpResp, err := c.client.Do(req) if err != nil { return nil, nil, errors.Wrapf(err, "failed to perform request for endpoint %s", url) @@ -92,16 +112,17 @@ func (c *BeaconApiRestHandler) GetSSZ(ctx context.Context, endpoint string) ([]b return } }() + accept := req.Header.Get("Accept") contentType := httpResp.Header.Get("Content-Type") body, err := io.ReadAll(httpResp.Body) if err != nil { return nil, nil, errors.Wrapf(err, "failed to read response body for %s", httpResp.Request.URL) } - if !strings.Contains(primaryAcceptType, contentType) { + + if !apiutil.PrimaryAcceptMatches(accept, contentType) { log.WithFields(logrus.Fields{ - "primaryAcceptType": primaryAcceptType, - "secondaryAcceptType": secondaryAcceptType, - "receivedAcceptType": contentType, + "Accept": accept, + "Content-Type": contentType, }).Debug("Server responded with non primary accept type") } @@ -115,10 +136,6 @@ func (c *BeaconApiRestHandler) GetSSZ(ctx context.Context, endpoint string) ([]b return nil, nil, errorJson } - if features.Get().SSZOnly && contentType != api.OctetStreamMediaType { - return nil, nil, errors.Errorf("server responded with non primary accept type %s", contentType) - } - return body, httpResp.Header, nil } diff --git a/validator/client/beacon-api/rest_handler_client_test.go b/validator/client/beacon-api/rest_handler_client_test.go index 6928e1f1e7..5f5ecc67f4 100644 --- a/validator/client/beacon-api/rest_handler_client_test.go +++ b/validator/client/beacon-api/rest_handler_client_test.go @@ -7,11 +7,13 @@ import ( "io" "net/http" "net/http/httptest" + "os" "testing" "time" "github.com/OffchainLabs/prysm/v6/api" "github.com/OffchainLabs/prysm/v6/api/server/structs" + "github.com/OffchainLabs/prysm/v6/config/params" "github.com/OffchainLabs/prysm/v6/network/httputil" "github.com/OffchainLabs/prysm/v6/testing/assert" "github.com/OffchainLabs/prysm/v6/testing/require" @@ -143,6 +145,25 @@ func TestGetSSZ(t *testing.T) { }) } +func TestAcceptOverrideSSZ(t *testing.T) { + name := "TestAcceptOverride" + orig := os.Getenv(params.EnvNameOverrideAccept) + defer func() { + require.NoError(t, os.Setenv(params.EnvNameOverrideAccept, orig)) + }() + require.NoError(t, os.Setenv(params.EnvNameOverrideAccept, name)) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, name, r.Header.Get("Accept")) + w.WriteHeader(200) + _, err := w.Write([]byte("ok")) + require.NoError(t, err) + })) + defer srv.Close() + c := NewBeaconApiRestHandler(http.Client{Timeout: time.Second * 5}, srv.URL) + _, _, err := c.GetSSZ(t.Context(), "/test") + require.NoError(t, err) +} + func TestPost(t *testing.T) { ctx := t.Context() const endpoint = "/example/rest/api/endpoint"