radek's advice moving logic to the rest handler

This commit is contained in:
james-prysm
2026-01-05 10:53:43 -06:00
parent a4bcbb7da6
commit 489ff40f94
5 changed files with 157 additions and 64 deletions

View File

@@ -3,7 +3,6 @@ package beacon_api
import (
"context"
"net/http"
"net/url"
"strconv"
"github.com/OffchainLabs/prysm/v7/api/server/structs"
@@ -106,29 +105,14 @@ func (c *beaconApiNodeClient) Peers(ctx context.Context, in *empty.Empty) (*ethp
// IsReady returns true only if the node is fully synced (200 OK).
// A 206 Partial Content response indicates the node is syncing and not ready.
func (c *beaconApiNodeClient) IsReady(ctx context.Context) bool {
endpoint, err := url.JoinPath(c.jsonRestHandler.Host(), "/eth/v1/node/health")
if err != nil {
log.WithError(err).Error("failed to construct health URL")
return false
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
log.WithError(err).Error("failed to create health request")
return false
}
resp, err := c.jsonRestHandler.HttpClient().Do(req)
statusCode, err := c.jsonRestHandler.GetStatusCode(ctx, "/eth/v1/node/health")
if err != nil {
log.WithError(err).Error("failed to get health of node")
return false
}
defer func() {
if err := resp.Body.Close(); err != nil {
log.WithError(err).Error("failed to close response body")
}
}()
// Only 200 OK means the node is fully synced and ready.
// 206 Partial Content means syncing, 503 means unavailable.
return resp.StatusCode == http.StatusOK
return statusCode == http.StatusOK
}
func NewNodeClientWithFallback(jsonRestHandler RestHandler, fallbackClient iface.NodeClient) iface.NodeClient {

View File

@@ -3,7 +3,6 @@ package beacon_api
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/OffchainLabs/prysm/v7/api/server/structs"
@@ -291,9 +290,12 @@ func TestGetVersion(t *testing.T) {
}
func TestIsReady(t *testing.T) {
const healthEndpoint = "/eth/v1/node/health"
testCases := []struct {
name string
statusCode int
err error
expectedResult bool
}{
{
@@ -316,6 +318,11 @@ func TestIsReady(t *testing.T) {
statusCode: http.StatusInternalServerError,
expectedResult: false,
},
{
name: "returns false on error",
err: errors.New("request failed"),
expectedResult: false,
},
}
for _, tc := range testCases {
@@ -324,15 +331,11 @@ func TestIsReady(t *testing.T) {
defer ctrl.Finish()
ctx := t.Context()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/eth/v1/node/health", r.URL.Path)
w.WriteHeader(tc.statusCode)
}))
defer server.Close()
jsonRestHandler := mock.NewMockJsonRestHandler(ctrl)
jsonRestHandler.EXPECT().Host().Return(server.URL)
jsonRestHandler.EXPECT().HttpClient().Return(server.Client())
jsonRestHandler.EXPECT().GetStatusCode(
gomock.Any(),
healthEndpoint,
).Return(tc.statusCode, tc.err)
nodeClient := &beaconApiNodeClient{jsonRestHandler: jsonRestHandler}
result := nodeClient.IsReady(ctx)

View File

@@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: validator/client/beacon-api/json_rest_handler.go
// Source: validator/client/beacon-api/rest_handler_client.go
//
// Generated by this command:
//
// mockgen -package=mock -source=validator/client/beacon-api/json_rest_handler.go -destination=validator/client/beacon-api/mock/json_rest_handler_mock.go
// mockgen -package=mock -source=validator/client/beacon-api/rest_handler_client.go -destination=validator/client/beacon-api/mock/json_rest_handler_mock.go RestHandler
//
// Package mock is a generated GoMock package.
@@ -18,32 +18,37 @@ import (
gomock "go.uber.org/mock/gomock"
)
// MockJsonRestHandler is a mock of JsonRestHandler interface.
type MockJsonRestHandler struct {
// Backward compatibility aliases for the renamed mock type.
type MockJsonRestHandler = MockRestHandler
type MockJsonRestHandlerMockRecorder = MockRestHandlerMockRecorder
var NewMockJsonRestHandler = NewMockRestHandler
// MockRestHandler is a mock of RestHandler interface.
type MockRestHandler struct {
ctrl *gomock.Controller
recorder *MockJsonRestHandlerMockRecorder
isgomock struct{}
recorder *MockRestHandlerMockRecorder
}
// MockJsonRestHandlerMockRecorder is the mock recorder for MockJsonRestHandler.
type MockJsonRestHandlerMockRecorder struct {
mock *MockJsonRestHandler
// MockRestHandlerMockRecorder is the mock recorder for MockRestHandler.
type MockRestHandlerMockRecorder struct {
mock *MockRestHandler
}
// NewMockJsonRestHandler creates a new mock instance.
func NewMockJsonRestHandler(ctrl *gomock.Controller) *MockJsonRestHandler {
mock := &MockJsonRestHandler{ctrl: ctrl}
mock.recorder = &MockJsonRestHandlerMockRecorder{mock}
// NewMockRestHandler creates a new mock instance.
func NewMockRestHandler(ctrl *gomock.Controller) *MockRestHandler {
mock := &MockRestHandler{ctrl: ctrl}
mock.recorder = &MockRestHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockJsonRestHandler) EXPECT() *MockJsonRestHandlerMockRecorder {
func (m *MockRestHandler) EXPECT() *MockRestHandlerMockRecorder {
return m.recorder
}
// Get mocks base method.
func (m *MockJsonRestHandler) Get(ctx context.Context, endpoint string, resp any) error {
func (m *MockRestHandler) Get(ctx context.Context, endpoint string, resp any) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", ctx, endpoint, resp)
ret0, _ := ret[0].(error)
@@ -51,13 +56,13 @@ func (m *MockJsonRestHandler) Get(ctx context.Context, endpoint string, resp any
}
// Get indicates an expected call of Get.
func (mr *MockJsonRestHandlerMockRecorder) Get(ctx, endpoint, resp any) *gomock.Call {
func (mr *MockRestHandlerMockRecorder) Get(ctx, endpoint, resp any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockJsonRestHandler)(nil).Get), ctx, endpoint, resp)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockRestHandler)(nil).Get), ctx, endpoint, resp)
}
// GetSSZ mocks base method.
func (m *MockJsonRestHandler) GetSSZ(ctx context.Context, endpoint string) ([]byte, http.Header, error) {
func (m *MockRestHandler) GetSSZ(ctx context.Context, endpoint string) ([]byte, http.Header, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSSZ", ctx, endpoint)
ret0, _ := ret[0].([]byte)
@@ -67,13 +72,28 @@ func (m *MockJsonRestHandler) GetSSZ(ctx context.Context, endpoint string) ([]by
}
// GetSSZ indicates an expected call of GetSSZ.
func (mr *MockJsonRestHandlerMockRecorder) GetSSZ(ctx, endpoint any) *gomock.Call {
func (mr *MockRestHandlerMockRecorder) GetSSZ(ctx, endpoint any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSSZ", reflect.TypeOf((*MockJsonRestHandler)(nil).GetSSZ), ctx, endpoint)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSSZ", reflect.TypeOf((*MockRestHandler)(nil).GetSSZ), ctx, endpoint)
}
// GetStatusCode mocks base method.
func (m *MockRestHandler) GetStatusCode(ctx context.Context, endpoint string) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetStatusCode", ctx, endpoint)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetStatusCode indicates an expected call of GetStatusCode.
func (mr *MockRestHandlerMockRecorder) GetStatusCode(ctx, endpoint any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatusCode", reflect.TypeOf((*MockRestHandler)(nil).GetStatusCode), ctx, endpoint)
}
// Host mocks base method.
func (m *MockJsonRestHandler) Host() string {
func (m *MockRestHandler) Host() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Host")
ret0, _ := ret[0].(string)
@@ -81,13 +101,13 @@ func (m *MockJsonRestHandler) Host() string {
}
// Host indicates an expected call of Host.
func (mr *MockJsonRestHandlerMockRecorder) Host() *gomock.Call {
func (mr *MockRestHandlerMockRecorder) Host() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Host", reflect.TypeOf((*MockJsonRestHandler)(nil).Host))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Host", reflect.TypeOf((*MockRestHandler)(nil).Host))
}
// HttpClient mocks base method.
func (m *MockJsonRestHandler) HttpClient() *http.Client {
func (m *MockRestHandler) HttpClient() *http.Client {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HttpClient")
ret0, _ := ret[0].(*http.Client)
@@ -95,13 +115,13 @@ func (m *MockJsonRestHandler) HttpClient() *http.Client {
}
// HttpClient indicates an expected call of HttpClient.
func (mr *MockJsonRestHandlerMockRecorder) HttpClient() *gomock.Call {
func (mr *MockRestHandlerMockRecorder) HttpClient() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HttpClient", reflect.TypeOf((*MockJsonRestHandler)(nil).HttpClient))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HttpClient", reflect.TypeOf((*MockRestHandler)(nil).HttpClient))
}
// Post mocks base method.
func (m *MockJsonRestHandler) Post(ctx context.Context, endpoint string, headers map[string]string, data *bytes.Buffer, resp any) error {
func (m *MockRestHandler) Post(ctx context.Context, endpoint string, headers map[string]string, data *bytes.Buffer, resp any) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Post", ctx, endpoint, headers, data, resp)
ret0, _ := ret[0].(error)
@@ -109,13 +129,13 @@ func (m *MockJsonRestHandler) Post(ctx context.Context, endpoint string, headers
}
// Post indicates an expected call of Post.
func (mr *MockJsonRestHandlerMockRecorder) Post(ctx, endpoint, headers, data, resp any) *gomock.Call {
func (mr *MockRestHandlerMockRecorder) Post(ctx, endpoint, headers, data, resp any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Post", reflect.TypeOf((*MockJsonRestHandler)(nil).Post), ctx, endpoint, headers, data, resp)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Post", reflect.TypeOf((*MockRestHandler)(nil).Post), ctx, endpoint, headers, data, resp)
}
// Post mocks base method.
func (m *MockJsonRestHandler) PostSSZ(ctx context.Context, endpoint string, headers map[string]string, data *bytes.Buffer) ([]byte, http.Header, error) {
// PostSSZ mocks base method.
func (m *MockRestHandler) PostSSZ(ctx context.Context, endpoint string, headers map[string]string, data *bytes.Buffer) ([]byte, http.Header, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PostSSZ", ctx, endpoint, headers, data)
ret0, _ := ret[0].([]byte)
@@ -124,20 +144,20 @@ func (m *MockJsonRestHandler) PostSSZ(ctx context.Context, endpoint string, head
return ret0, ret1, ret2
}
// Post indicates an expected call of Post.
func (mr *MockJsonRestHandlerMockRecorder) PostSSZ(ctx, endpoint, headers, data any) *gomock.Call {
// PostSSZ indicates an expected call of PostSSZ.
func (mr *MockRestHandlerMockRecorder) PostSSZ(ctx, endpoint, headers, data any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostSSZ", reflect.TypeOf((*MockJsonRestHandler)(nil).PostSSZ), ctx, endpoint, headers, data)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PostSSZ", reflect.TypeOf((*MockRestHandler)(nil).PostSSZ), ctx, endpoint, headers, data)
}
// SetHost mocks base method.
func (m *MockJsonRestHandler) SetHost(host string) {
func (m *MockRestHandler) SetHost(host string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetHost", host)
}
// SetHost indicates an expected call of SetHost.
func (mr *MockJsonRestHandlerMockRecorder) SetHost(host any) *gomock.Call {
func (mr *MockRestHandlerMockRecorder) SetHost(host any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHost", reflect.TypeOf((*MockJsonRestHandler)(nil).SetHost), host)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHost", reflect.TypeOf((*MockRestHandler)(nil).SetHost), host)
}

View File

@@ -23,6 +23,7 @@ type reqOption func(*http.Request)
type RestHandler interface {
Get(ctx context.Context, endpoint string, resp any) error
GetStatusCode(ctx context.Context, endpoint string) (int, error)
GetSSZ(ctx context.Context, endpoint string) ([]byte, http.Header, error)
Post(ctx context.Context, endpoint string, headers map[string]string, data *bytes.Buffer, resp any) error
PostSSZ(ctx context.Context, endpoint string, headers map[string]string, data *bytes.Buffer) ([]byte, http.Header, error)
@@ -90,6 +91,28 @@ func (c *BeaconApiRestHandler) Get(ctx context.Context, endpoint string, resp an
return decodeResp(httpResp, resp)
}
// GetStatusCode sends a GET request and returns only the HTTP status code.
// This is useful for endpoints like /eth/v1/node/health that communicate status via HTTP codes
// (200 = ready, 206 = syncing, 503 = unavailable) rather than response bodies.
func (c *BeaconApiRestHandler) GetStatusCode(ctx context.Context, endpoint string) (int, error) {
url := c.host + endpoint
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return 0, errors.Wrapf(err, "failed to create request for endpoint %s", url)
}
req.Header.Set("User-Agent", version.BuildData())
httpResp, err := c.client.Do(req)
if err != nil {
return 0, errors.Wrapf(err, "failed to perform request for endpoint %s", url)
}
defer func() {
if err := httpResp.Body.Close(); err != nil {
return
}
}()
return httpResp.StatusCode, nil
}
func (c *BeaconApiRestHandler) GetSSZ(ctx context.Context, endpoint string) ([]byte, http.Header, error) {
url := c.host + endpoint
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)

View File

@@ -359,3 +359,66 @@ func Test_decodeResp(t *testing.T) {
assert.ErrorContains(t, "HTTP request unsuccessful (500: foo)", err)
})
}
func TestGetStatusCode(t *testing.T) {
ctx := t.Context()
const endpoint = "/eth/v1/node/health"
testCases := []struct {
name string
serverStatusCode int
expectedStatusCode int
}{
{
name: "returns 200 OK",
serverStatusCode: http.StatusOK,
expectedStatusCode: http.StatusOK,
},
{
name: "returns 206 Partial Content",
serverStatusCode: http.StatusPartialContent,
expectedStatusCode: http.StatusPartialContent,
},
{
name: "returns 503 Service Unavailable",
serverStatusCode: http.StatusServiceUnavailable,
expectedStatusCode: http.StatusServiceUnavailable,
},
{
name: "returns 500 Internal Server Error",
serverStatusCode: http.StatusInternalServerError,
expectedStatusCode: http.StatusInternalServerError,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc(endpoint, func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, version.BuildData(), r.Header.Get("User-Agent"))
w.WriteHeader(tc.serverStatusCode)
})
server := httptest.NewServer(mux)
defer server.Close()
jsonRestHandler := BeaconApiRestHandler{
client: http.Client{Timeout: time.Second * 5},
host: server.URL,
}
statusCode, err := jsonRestHandler.GetStatusCode(ctx, endpoint)
require.NoError(t, err)
assert.Equal(t, tc.expectedStatusCode, statusCode)
})
}
t.Run("returns error on connection failure", func(t *testing.T) {
jsonRestHandler := BeaconApiRestHandler{
client: http.Client{Timeout: time.Millisecond * 100},
host: "http://localhost:99999", // Invalid port
}
_, err := jsonRestHandler.GetStatusCode(ctx, endpoint)
require.ErrorContains(t, "failed to perform request", err)
})
}