Safe StreamEvents write loop (#14557)

* new type for tests where errors are only logged

* StreamHandler waits for write loop exit

* add test case for writer timeout

* add changelog

* add missing file

* logging fix

* fix logging test to allow info logs

* naming/comments; make response controller private

* simplify cancel defers

* fix typo in test file name

---------

Co-authored-by: Kasey Kirkham <kasey@users.noreply.github.com>
Authored by kasey, 2024-10-24 14:16:17 -05:00; committed by GitHub
parent 83ed320826
commit 52cf3a155d
10 changed files with 485 additions and 94 deletions
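Before the per-file diffs, here is a minimal sketch of the pattern this PR introduces (illustrative names, not Prysm's actual code): the handler must block until its async writer goroutine has exited before returning, otherwise that goroutine can keep writing to an http.ResponseWriter the server has already reclaimed, which is the race-condition panic the changelog entry describes.

```go
package main

import (
	"context"
	"fmt"
	"net/http"
)

// sseHandler sketches the fix: the handler owns an async write loop and
// blocks on done before returning, so no goroutine can touch w after the
// handler exits.
func sseHandler(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	outbox := make(chan string, 8) // buffered event queue, like es.outbox
	done := make(chan struct{})    // closed by the write loop when it exits

	go func() { // analogous to outboxWriteLoop
		defer close(done)
		for {
			select {
			case <-ctx.Done():
				return
			case ev := <-outbox:
				fmt.Fprintf(w, "data: %s\n\n", ev)
			}
		}
	}()

	outbox <- "hello" // stand-in for recvEventLoop feeding events
	cancel()          // ask the write loop to stop
	<-done            // like es.waitForExit(): only now is it safe to return
}

func main() {
	http.HandleFunc("/events", sseHandler)
	_ = http.ListenAndServe("localhost:8080", nil)
}
```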

View File: CHANGELOG.md

@@ -42,6 +42,7 @@ The format is based on Keep a Changelog, and this project adheres to Semantic Ve
- Fixed mesh size by appending `gParams.Dhi = gossipSubDhi`
- Fix skipping partial withdrawals count.
- Wait for the async StreamEvent writer to exit before leaving the http handler, avoiding race condition panics [pr](https://github.com/prysmaticlabs/prysm/pull/14557)
### Security

View File: beacon-chain/rpc/eth/events/BUILD.bazel

@@ -4,6 +4,7 @@ go_library(
name = "go_default_library",
srcs = [
"events.go",
"log.go",
"server.go",
],
importpath = "github.com/prysmaticlabs/prysm/v5/beacon-chain/rpc/eth/events",
@@ -58,5 +59,6 @@ go_test(
"//testing/util:go_default_library",
"@com_github_ethereum_go_ethereum//common:go_default_library",
"@com_github_r3labs_sse_v2//:go_default_library",
"@com_github_sirupsen_logrus//:go_default_library",
],
)

View File: beacon-chain/rpc/eth/events/events.go

@@ -28,7 +28,6 @@ import (
eth "github.com/prysmaticlabs/prysm/v5/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/v5/runtime/version"
"github.com/prysmaticlabs/prysm/v5/time/slots"
log "github.com/sirupsen/logrus"
)
const DefaultEventFeedDepth = 1000
@@ -74,13 +73,6 @@ var (
errWriterUnusable = errors.New("http response writer is unusable")
)
- // StreamingResponseWriter defines a type that can be used by the eventStreamer.
- // This must be an http.ResponseWriter that supports flushing and hijacking.
- type StreamingResponseWriter interface {
- http.ResponseWriter
- http.Flusher
- }
// The eventStreamer uses lazyReaders to defer serialization until the moment the value is ready to be written to the client.
type lazyReader func() io.Reader
@@ -150,6 +142,7 @@ func newTopicRequest(topics []string) (*topicRequest, error) {
// Servers may send SSE comments beginning with ':' for any purpose,
// including to keep the event stream connection alive in the presence of proxy servers.
func (s *Server) StreamEvents(w http.ResponseWriter, r *http.Request) {
+ log.Debug("Starting StreamEvents handler")
ctx, span := trace.StartSpan(r.Context(), "events.StreamEvents")
defer span.End()
@@ -159,47 +152,51 @@ func (s *Server) StreamEvents(w http.ResponseWriter, r *http.Request) {
return
}
- sw, ok := w.(StreamingResponseWriter)
- if !ok {
- msg := "beacon node misconfiguration: http stack may not support required response handling features, like flushing"
- httputil.HandleError(w, msg, http.StatusInternalServerError)
- return
+ timeout := s.EventWriteTimeout
+ if timeout == 0 {
+ timeout = time.Duration(params.BeaconConfig().SecondsPerSlot) * time.Second
}
- depth := s.EventFeedDepth
- if depth == 0 {
- depth = DefaultEventFeedDepth
+ ka := s.KeepAliveInterval
+ if ka == 0 {
+ ka = timeout
}
- es, err := newEventStreamer(depth, s.KeepAliveInterval)
- if err != nil {
- httputil.HandleError(w, err.Error(), http.StatusInternalServerError)
- return
+ buffSize := s.EventFeedDepth
+ if buffSize == 0 {
+ buffSize = DefaultEventFeedDepth
}
- api.SetSSEHeaders(w)
+ sw := newStreamingResponseController(w, timeout)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
+ api.SetSSEHeaders(sw)
+ es := newEventStreamer(buffSize, ka)
go es.outboxWriteLoop(ctx, cancel, sw)
if err := es.recvEventLoop(ctx, cancel, topics, s); err != nil {
log.WithError(err).Debug("Shutting down StreamEvents handler.")
}
cleanupStart := time.Now()
es.waitForExit()
log.WithField("cleanup_wait", time.Since(cleanupStart)).Debug("streamEvents shutdown complete")
}
- func newEventStreamer(buffSize int, ka time.Duration) (*eventStreamer, error) {
- if ka == 0 {
- ka = time.Duration(params.BeaconConfig().SecondsPerSlot) * time.Second
- }
+ func newEventStreamer(buffSize int, ka time.Duration) *eventStreamer {
return &eventStreamer{
- outbox: make(chan lazyReader, buffSize),
- keepAlive: ka,
- }, nil
+ outbox: make(chan lazyReader, buffSize),
+ keepAlive: ka,
+ openUntilExit: make(chan struct{}),
+ }
}
type eventStreamer struct {
- outbox chan lazyReader
- keepAlive time.Duration
+ outbox chan lazyReader
+ keepAlive time.Duration
+ openUntilExit chan struct{}
}
func (es *eventStreamer) recvEventLoop(ctx context.Context, cancel context.CancelFunc, req *topicRequest, s *Server) error {
defer close(es.outbox)
+ defer cancel()
eventsChan := make(chan *feed.Event, len(es.outbox))
if req.needOpsFeed {
opsSub := s.OperationNotifier.OperationFeed().Subscribe(eventsChan)
@@ -228,7 +225,6 @@ func (es *eventStreamer) recvEventLoop(ctx context.Context, cancel context.Cance
// channel should stay relatively empty, which gives this loop time to unsubscribe
// and cleanup before the event stream channel fills and disrupts other readers.
if err := es.safeWrite(ctx, lr); err != nil {
- cancel()
// note: we could hijack the connection and close it here. Does that cause issues? What are the benefits?
// A benefit of hijack and close is that it may force an error on the remote end, however just closing the context of the
// http handler may be sufficient to cause the remote http response reader to close.
@@ -265,12 +261,13 @@ func newlineReader() io.Reader {
// outboxWriteLoop runs in a separate goroutine. Its job is to write the values in the outbox to
// the client as fast as the client can read them.
- func (es *eventStreamer) outboxWriteLoop(ctx context.Context, cancel context.CancelFunc, w StreamingResponseWriter) {
+ func (es *eventStreamer) outboxWriteLoop(ctx context.Context, cancel context.CancelFunc, w *streamingResponseWriterController) {
var err error
defer func() {
if err != nil {
log.WithError(err).Debug("Event streamer shutting down due to error.")
}
+ es.exit()
}()
defer cancel()
// Write a keepalive at the start to test the connection and simplify test setup.
@@ -310,18 +307,43 @@ func (es *eventStreamer) outboxWriteLoop(ctx context.Context, cancel context.Can
}
}
- func writeLazyReaderWithRecover(w StreamingResponseWriter, lr lazyReader) (err error) {
+ func (es *eventStreamer) exit() {
+ drained := 0
+ for range es.outbox {
+ drained += 1
+ }
+ log.WithField("undelivered_events", drained).Debug("Event stream outbox drained.")
+ close(es.openUntilExit)
+ }
+ // waitForExit blocks until the outboxWriteLoop has exited.
+ // While this function blocks, it is not yet safe to exit the http handler,
+ // because the outboxWriteLoop may still be writing to the http ResponseWriter.
+ func (es *eventStreamer) waitForExit() {
+ <-es.openUntilExit
+ }
+ func writeLazyReaderWithRecover(w *streamingResponseWriterController, lr lazyReader) (err error) {
defer func() {
if r := recover(); r != nil {
log.WithField("panic", r).Error("Recovered from panic while writing event to client.")
err = errWriterUnusable
}
}()
- _, err = io.Copy(w, lr())
+ r := lr()
+ out, err := io.ReadAll(r)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(out)
return err
}
- func (es *eventStreamer) writeOutbox(ctx context.Context, w StreamingResponseWriter, first lazyReader) error {
+ func (es *eventStreamer) writeOutbox(ctx context.Context, w *streamingResponseWriterController, first lazyReader) error {
// The outboxWriteLoop is responsible for managing the keep-alive timer and toggling between reading from the outbox
// when it is ready, only allowing the keep-alive to fire when there hasn't been a write in the keep-alive interval.
// Since outboxWriteLoop will get either the first event or the keep-alive, we let it pass in the first event to write,
// either the event's lazyReader, or nil for a keep-alive.
needKeepAlive := true
if first != nil {
if err := writeLazyReaderWithRecover(w, first); err != nil {
@@ -337,6 +359,11 @@ func (es *eventStreamer) writeOutbox(ctx context.Context, w StreamingResponseWri
case <-ctx.Done():
return ctx.Err()
case rf := <-es.outbox:
+ // We don't want to call Flush until we've exhausted all the writes - it's always preferable to
+ // just keep draining the outbox and rely on the underlying Write code to flush+block when it
+ // needs to based on buffering. Whenever we fill the buffer with a string of writes, the underlying
+ // code will flush on its own, so it's better to explicitly flush only once, after we've totally
+ // drained the outbox, to catch any dangling bytes stuck in a buffer.
if err := writeLazyReaderWithRecover(w, rf); err != nil {
return err
}
@@ -347,8 +374,7 @@ func (es *eventStreamer) writeOutbox(ctx context.Context, w StreamingResponseWri
return err
}
}
- w.Flush()
- return nil
+ return w.Flush()
}
}
}
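The flush-batching comment above is essentially the contract of a buffered writer. A tiny standalone sketch of "drain everything, then flush once" (illustrative, not from the PR; bufio and os.Stdout stand in for the HTTP stack's buffering):

```go
package main

import (
	"bufio"
	"os"
)

// drain writes every queued event, letting the buffered writer flush on its
// own whenever its buffer fills mid-loop, then flushes exactly once at the
// end to push any dangling buffered bytes: the same shape as writeOutbox.
func drain(pending [][]byte) error {
	w := bufio.NewWriter(os.Stdout) // stand-in for the HTTP response
	for _, ev := range pending {
		if _, err := w.Write(ev); err != nil {
			return err
		}
	}
	return w.Flush()
}

func main() {
	_ = drain([][]byte{[]byte("data: a\n\n"), []byte("data: b\n\n")})
}
```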
@@ -638,3 +664,51 @@ func (s *Server) currentPayloadAttributes(ctx context.Context) (lazyReader, erro
})
}, nil
}
func newStreamingResponseController(rw http.ResponseWriter, timeout time.Duration) *streamingResponseWriterController {
rc := http.NewResponseController(rw)
return &streamingResponseWriterController{
timeout: timeout,
rw: rw,
rc: rc,
}
}
// streamingResponseWriterController provides an interface similar to an http.ResponseWriter,
// wrapping an http.ResponseWriter and an http.ResponseController, using the ResponseController
// to set and clear deadlines for Write and Flush methods, and delegating to the underlying
// types to Write and Flush.
type streamingResponseWriterController struct {
timeout time.Duration
rw http.ResponseWriter
rc *http.ResponseController
}
func (c *streamingResponseWriterController) Write(b []byte) (int, error) {
if err := c.setDeadline(); err != nil {
return 0, err
}
out, err := c.rw.Write(b)
if err != nil {
return out, err
}
return out, c.clearDeadline()
}
func (c *streamingResponseWriterController) setDeadline() error {
return c.rc.SetWriteDeadline(time.Now().Add(c.timeout))
}
func (c *streamingResponseWriterController) clearDeadline() error {
return c.rc.SetWriteDeadline(time.Time{})
}
func (c *streamingResponseWriterController) Flush() error {
if err := c.setDeadline(); err != nil {
return err
}
if err := c.rc.Flush(); err != nil {
return err
}
return c.clearDeadline()
}
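A usage sketch for the controller above (writeOneEvent is a hypothetical helper, assumed to sit in the same events package; the timeout value is illustrative). http.ResponseController is Go 1.20+; because each Write and Flush gets a fresh deadline, a wedged client fails that single call with a timeout instead of hanging the write loop forever, which is what the stuck-writer test below exercises:

```go
// writeOneEvent is a hypothetical helper, not part of the PR, showing the
// intended calling pattern for streamingResponseWriterController.
func writeOneEvent(rw http.ResponseWriter, payload []byte) error {
	sw := newStreamingResponseController(rw, 2*time.Second) // illustrative timeout
	if _, err := sw.Write(payload); err != nil {
		return err // e.g. an i/o timeout when the client stops reading
	}
	return sw.Flush() // Flush is deadline-bounded as well
}
```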

View File: beacon-chain/rpc/eth/events/events_test.go

@@ -27,9 +27,12 @@ import (
"github.com/prysmaticlabs/prysm/v5/testing/require"
"github.com/prysmaticlabs/prysm/v5/testing/util"
sse "github.com/r3labs/sse/v2"
"github.com/sirupsen/logrus"
)
- func requireAllEventsReceived(t *testing.T, stn, opn *mockChain.EventFeedWrapper, events []*feed.Event, req *topicRequest, s *Server, w *StreamingResponseWriterRecorder) {
+ var testEventWriteTimeout = 100 * time.Millisecond
+ func requireAllEventsReceived(t *testing.T, stn, opn *mockChain.EventFeedWrapper, events []*feed.Event, req *topicRequest, s *Server, w *StreamingResponseWriterRecorder, logs chan *logrus.Entry) {
// maxBufferSize param copied from sse lib client code
sseR := sse.NewEventStreamReader(w.Body(), 1<<24)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
@@ -77,21 +80,29 @@ func requireAllEventsReceived(t *testing.T, stn, opn *mockChain.EventFeedWrapper
}
}
}()
- select {
- case <-done:
- break
- case <-ctx.Done():
- t.Fatalf("context canceled / timed out waiting for events, err=%v", ctx.Err())
+ for {
+ select {
+ case entry := <-logs:
+ errAttr, ok := entry.Data[logrus.ErrorKey]
+ if ok {
+ t.Errorf("unexpected error in logs: %v", errAttr)
+ }
+ case <-done:
+ require.Equal(t, 0, len(expected), "expected events not seen")
+ return
+ case <-ctx.Done():
+ t.Fatalf("context canceled / timed out waiting for events, err=%v", ctx.Err())
+ }
}
+ }
- require.Equal(t, 0, len(expected), "expected events not seen")
}
- func (tr *topicRequest) testHttpRequest(_ *testing.T) *http.Request {
+ func (tr *topicRequest) testHttpRequest(ctx context.Context, _ *testing.T) *http.Request {
tq := make([]string, 0, len(tr.topics))
for topic := range tr.topics {
tq = append(tq, "topics="+topic)
}
- return httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://example.com/eth/v1/events?%s", strings.Join(tq, "&")), nil)
+ req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://example.com/eth/v1/events?%s", strings.Join(tq, "&")), nil)
+ return req.WithContext(ctx)
}
func operationEventsFixtures(t *testing.T) (*topicRequest, []*feed.Event) {
@@ -235,31 +246,77 @@ func operationEventsFixtures(t *testing.T) (*topicRequest, []*feed.Event) {
}
}
type streamTestSync struct {
done chan struct{}
cancel func()
undo func()
logs chan *logrus.Entry
ctx context.Context
t *testing.T
}
func (s *streamTestSync) cleanup() {
s.cancel()
select {
case <-s.done:
case <-time.After(10 * time.Millisecond):
s.t.Fatal("timed out waiting for handler to finish")
}
s.undo()
}
func (s *streamTestSync) markDone() {
close(s.done)
}
func newStreamTestSync(t *testing.T) *streamTestSync {
logChan := make(chan *logrus.Entry, 100)
cew := util.NewChannelEntryWriter(logChan)
undo := util.RegisterHookWithUndo(logger, cew)
ctx, cancel := context.WithCancel(context.Background())
return &streamTestSync{
t: t,
ctx: ctx,
cancel: cancel,
logs: logChan,
undo: undo,
done: make(chan struct{}),
}
}
func TestStreamEvents_OperationsEvents(t *testing.T) {
t.Run("operations", func(t *testing.T) {
+ testSync := newStreamTestSync(t)
+ defer testSync.cleanup()
stn := mockChain.NewEventFeedWrapper()
opn := mockChain.NewEventFeedWrapper()
s := &Server{
StateNotifier: &mockChain.SimpleNotifier{Feed: stn},
OperationNotifier: &mockChain.SimpleNotifier{Feed: opn},
+ EventWriteTimeout: testEventWriteTimeout,
}
topics, events := operationEventsFixtures(t)
- request := topics.testHttpRequest(t)
- w := NewStreamingResponseWriterRecorder()
+ request := topics.testHttpRequest(testSync.ctx, t)
+ w := NewStreamingResponseWriterRecorder(testSync.ctx)
go func() {
s.StreamEvents(w, request)
+ testSync.markDone()
}()
- requireAllEventsReceived(t, stn, opn, events, topics, s, w)
+ requireAllEventsReceived(t, stn, opn, events, topics, s, w, testSync.logs)
})
t.Run("state", func(t *testing.T) {
+ testSync := newStreamTestSync(t)
+ defer testSync.cleanup()
stn := mockChain.NewEventFeedWrapper()
opn := mockChain.NewEventFeedWrapper()
s := &Server{
StateNotifier: &mockChain.SimpleNotifier{Feed: stn},
OperationNotifier: &mockChain.SimpleNotifier{Feed: opn},
+ EventWriteTimeout: testEventWriteTimeout,
}
topics, err := newTopicRequest([]string{
@@ -269,8 +326,8 @@ func TestStreamEvents_OperationsEvents(t *testing.T) {
BlockTopic,
})
require.NoError(t, err)
- request := topics.testHttpRequest(t)
- w := NewStreamingResponseWriterRecorder()
+ request := topics.testHttpRequest(testSync.ctx, t)
+ w := NewStreamingResponseWriterRecorder(testSync.ctx)
b, err := blocks.NewSignedBeaconBlock(util.HydrateSignedBeaconBlock(&eth.SignedBeaconBlock{}))
require.NoError(t, err)
@@ -323,9 +380,10 @@ func TestStreamEvents_OperationsEvents(t *testing.T) {
go func() {
s.StreamEvents(w, request)
+ testSync.markDone()
}()
- requireAllEventsReceived(t, stn, opn, events, topics, s, w)
+ requireAllEventsReceived(t, stn, opn, events, topics, s, w, testSync.logs)
})
t.Run("payload attributes", func(t *testing.T) {
type testCase struct {
@@ -396,59 +454,93 @@ func TestStreamEvents_OperationsEvents(t *testing.T) {
},
}
for _, tc := range testCases {
- st := tc.getState()
- v := &eth.Validator{ExitEpoch: math.MaxUint64}
- require.NoError(t, st.SetValidators([]*eth.Validator{v}))
- currentSlot := primitives.Slot(0)
- // to avoid slot processing
- require.NoError(t, st.SetSlot(currentSlot+1))
- b := tc.getBlock()
- mockChainService := &mockChain.ChainService{
- Root: make([]byte, 32),
- State: st,
- Block: b,
- Slot: &currentSlot,
- }
t.Run(tc.name, func(t *testing.T) {
+ testSync := newStreamTestSync(t)
+ defer testSync.cleanup()
- stn := mockChain.NewEventFeedWrapper()
- opn := mockChain.NewEventFeedWrapper()
- s := &Server{
- StateNotifier: &mockChain.SimpleNotifier{Feed: stn},
- OperationNotifier: &mockChain.SimpleNotifier{Feed: opn},
- HeadFetcher: mockChainService,
- ChainInfoFetcher: mockChainService,
- TrackedValidatorsCache: cache.NewTrackedValidatorsCache(),
- }
- if tc.SetTrackedValidatorsCache != nil {
- tc.SetTrackedValidatorsCache(s.TrackedValidatorsCache)
- }
- topics, err := newTopicRequest([]string{PayloadAttributesTopic})
- require.NoError(t, err)
- request := topics.testHttpRequest(t)
- w := NewStreamingResponseWriterRecorder()
- events := []*feed.Event{&feed.Event{Type: statefeed.MissedSlot}}
+ st := tc.getState()
+ v := &eth.Validator{ExitEpoch: math.MaxUint64}
+ require.NoError(t, st.SetValidators([]*eth.Validator{v}))
+ currentSlot := primitives.Slot(0)
+ // to avoid slot processing
+ require.NoError(t, st.SetSlot(currentSlot+1))
+ b := tc.getBlock()
+ mockChainService := &mockChain.ChainService{
+ Root: make([]byte, 32),
+ State: st,
+ Block: b,
+ Slot: &currentSlot,
+ }
- go func() {
- s.StreamEvents(w, request)
- }()
- requireAllEventsReceived(t, stn, opn, events, topics, s, w)
+ stn := mockChain.NewEventFeedWrapper()
+ opn := mockChain.NewEventFeedWrapper()
+ s := &Server{
+ StateNotifier: &mockChain.SimpleNotifier{Feed: stn},
+ OperationNotifier: &mockChain.SimpleNotifier{Feed: opn},
+ HeadFetcher: mockChainService,
+ ChainInfoFetcher: mockChainService,
+ TrackedValidatorsCache: cache.NewTrackedValidatorsCache(),
+ EventWriteTimeout: testEventWriteTimeout,
+ }
+ if tc.SetTrackedValidatorsCache != nil {
+ tc.SetTrackedValidatorsCache(s.TrackedValidatorsCache)
+ }
+ topics, err := newTopicRequest([]string{PayloadAttributesTopic})
+ require.NoError(t, err)
+ request := topics.testHttpRequest(testSync.ctx, t)
+ w := NewStreamingResponseWriterRecorder(testSync.ctx)
+ events := []*feed.Event{&feed.Event{Type: statefeed.MissedSlot}}
+ go func() {
+ s.StreamEvents(w, request)
+ testSync.markDone()
+ }()
+ requireAllEventsReceived(t, stn, opn, events, topics, s, w, testSync.logs)
})
}
})
}
- func TestStuckReader(t *testing.T) {
+ func TestStuckReaderScenarios(t *testing.T) {
+ cases := []struct {
+ name string
+ queueDepth func([]*feed.Event) int
+ }{
+ {
+ name: "slow reader - queue overflows",
+ queueDepth: func(events []*feed.Event) int {
+ return len(events) - 1
+ },
+ },
+ {
+ name: "slow reader - all queued, but writer is stuck, write timeout",
+ queueDepth: func(events []*feed.Event) int {
+ return len(events) + 1
+ },
+ },
+ }
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ wedgedWriterTestCase(t, c.queueDepth)
+ })
+ }
+ }
+ func wedgedWriterTestCase(t *testing.T, queueDepth func([]*feed.Event) int) {
topics, events := operationEventsFixtures(t)
require.Equal(t, 8, len(events))
// set eventFeedDepth to a number lower than the events we intend to send to force the server to drop the reader.
stn := mockChain.NewEventFeedWrapper()
opn := mockChain.NewEventFeedWrapper()
s := &Server{
+ EventWriteTimeout: 10 * time.Millisecond,
StateNotifier: &mockChain.SimpleNotifier{Feed: stn},
OperationNotifier: &mockChain.SimpleNotifier{Feed: opn},
- EventFeedDepth: len(events) - 1,
+ EventFeedDepth: queueDepth(events),
}
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
eventsWritten := make(chan struct{})
go func() {
@@ -468,8 +560,8 @@ func TestStuckReader(t *testing.T) {
close(eventsWritten)
}()
- request := topics.testHttpRequest(t)
- w := NewStreamingResponseWriterRecorder()
+ request := topics.testHttpRequest(ctx, t)
+ w := NewStreamingResponseWriterRecorder(ctx)
handlerFinished := make(chan struct{})
go func() {

View File: beacon-chain/rpc/eth/events/http_test.go

@@ -1,10 +1,12 @@
package events
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/prysmaticlabs/prysm/v5/testing/require"
)
@@ -17,32 +19,66 @@ type StreamingResponseWriterRecorder struct {
status chan int
bodyRecording []byte
flushed bool
+ writeDeadline time.Time
+ ctx context.Context
}
func (w *StreamingResponseWriterRecorder) StatusChan() chan int {
return w.status
}
- func NewStreamingResponseWriterRecorder() *StreamingResponseWriterRecorder {
+ func NewStreamingResponseWriterRecorder(ctx context.Context) *StreamingResponseWriterRecorder {
r, w := io.Pipe()
return &StreamingResponseWriterRecorder{
ResponseWriter: httptest.NewRecorder(),
r: r,
w: w,
status: make(chan int, 1),
+ ctx: ctx,
}
}
// Write implements http.ResponseWriter.
func (w *StreamingResponseWriterRecorder) Write(data []byte) (int, error) {
w.WriteHeader(http.StatusOK)
- n, err := w.w.Write(data)
+ written, err := writeWithDeadline(w.ctx, w.w, data, w.writeDeadline)
if err != nil {
- return n, err
+ return written, err
}
}
// The test response writer is non-blocking.
return w.ResponseWriter.Write(data)
}
var zeroTimeValue = time.Time{}
func writeWithDeadline(ctx context.Context, w io.Writer, data []byte, deadline time.Time) (int, error) {
result := struct {
written int
err error
}{}
done := make(chan struct{})
go func() {
defer close(done)
result.written, result.err = w.Write(data)
}()
if deadline == zeroTimeValue {
select {
case <-ctx.Done():
return 0, ctx.Err()
case <-done:
return result.written, result.err
}
}
select {
case <-time.After(time.Until(deadline)):
return 0, http.ErrHandlerTimeout
case <-done:
return result.written, result.err
case <-ctx.Done():
return 0, ctx.Err()
}
}
// WriteHeader implements http.ResponseWriter.
func (w *StreamingResponseWriterRecorder) WriteHeader(statusCode int) {
if w.statusWritten != nil {
@@ -65,6 +101,7 @@ func (w *StreamingResponseWriterRecorder) RequireStatus(t *testing.T, status int
}
func (w *StreamingResponseWriterRecorder) Flush() {
+ w.WriteHeader(200)
fw, ok := w.ResponseWriter.(http.Flusher)
if ok {
fw.Flush()
@@ -72,4 +109,9 @@ func (w *StreamingResponseWriterRecorder) Flush() {
w.flushed = true
}
+ func (w *StreamingResponseWriterRecorder) SetWriteDeadline(d time.Time) error {
+ w.writeDeadline = d
+ return nil
+ }
var _ http.ResponseWriter = &StreamingResponseWriterRecorder{}
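A sketch of how a test can drive this recorder (hypothetical test, same package assumed, and assuming Body() returns the pipe's read end, as requireAllEventsReceived uses above): because writes go through an io.Pipe, Write blocks until a reader drains it, which is how these tests simulate slow or stuck clients:

```go
func TestRecorderSketch(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	w := NewStreamingResponseWriterRecorder(ctx)

	go func() {
		// Blocks until the reader below drains the pipe.
		_, _ = w.Write([]byte("data: hi\n\n"))
	}()

	buf := make([]byte, 64)
	n, err := w.Body().Read(buf) // plays the role of the consuming client
	require.NoError(t, err)
	require.Equal(t, "data: hi\n\n", string(buf[:n]))
}
```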

View File: beacon-chain/rpc/eth/events/log.go

@@ -0,0 +1,6 @@
package events
import "github.com/sirupsen/logrus"
var logger = logrus.StandardLogger()
var log = logger.WithField("prefix", "events")

View File: beacon-chain/rpc/eth/events/server.go

@@ -22,4 +22,5 @@ type Server struct {
TrackedValidatorsCache *cache.TrackedValidatorsCache
KeepAliveInterval time.Duration
EventFeedDepth int
+ EventWriteTimeout time.Duration
}

View File: testing/util/BUILD.bazel

@@ -21,6 +21,7 @@ go_library(
"electra_state.go",
"helpers.go",
"lightclient.go",
"logging.go",
"merge.go",
"state.go",
"sync_aggregate.go",
@@ -69,6 +70,7 @@ go_library(
"@com_github_pkg_errors//:go_default_library",
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
"@com_github_sirupsen_logrus//:go_default_library",
"@com_github_sirupsen_logrus//hooks/test:go_default_library",
"@io_bazel_rules_go//go/tools/bazel:go_default_library",
],
)
@@ -83,6 +85,7 @@ go_test(
"deneb_test.go",
"deposits_test.go",
"helpers_test.go",
"logging_test.go",
"state_test.go",
],
embed = [":go_default_library"],
@@ -106,6 +109,8 @@ go_test(
"//testing/assert:go_default_library",
"//testing/require:go_default_library",
"//time/slots:go_default_library",
"@com_github_pkg_errors//:go_default_library",
"@com_github_sirupsen_logrus//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)

View File: testing/util/logging.go (new file, 90 lines)

@@ -0,0 +1,90 @@
package util
import (
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
)
// ComparableHook is an interface that allows hooks to be uniquely identified
// so that tests can safely unregister them as part of cleanup.
type ComparableHook interface {
logrus.Hook
Equal(other logrus.Hook) bool
}
// UnregisterHook removes a hook that implements the ComparableHook interface
// from all levels of the given logger.
func UnregisterHook(logger *logrus.Logger, unregister ComparableHook) {
found := false
replace := make(logrus.LevelHooks)
for lvl, hooks := range logger.Hooks {
for _, h := range hooks {
if unregister.Equal(h) {
found = true
continue
}
replace[lvl] = append(replace[lvl], h)
}
}
if !found {
return
}
logger.ReplaceHooks(replace)
}
var highestLevel logrus.Level
// RegisterHookWithUndo adds a hook to the logger and
// returns a function that can be called to remove it. This is intended to be used in tests
// to ensure that test hooks are removed after the test is complete.
func RegisterHookWithUndo(logger *logrus.Logger, hook ComparableHook) func() {
level := logger.Level
logger.AddHook(hook)
// set level to highest possible to ensure that hook is called for all log levels
logger.SetLevel(highestLevel)
return func() {
UnregisterHook(logger, hook)
logger.SetLevel(level)
}
}
// NewChannelEntryWriter creates a new ChannelEntryWriter.
// The channel argument will be sent all log entries.
// Note that if this is an unbuffered channel, it is the responsibility
// of the code using it to make sure that it is drained appropriately,
// or calls to the logger can block.
func NewChannelEntryWriter(c chan *logrus.Entry) *ChannelEntryWriter {
return &ChannelEntryWriter{c: c}
}
// ChannelEntryWriter embeds/wraps the test.Hook struct
// and adds a channel to receive log entries every time the
// Fire method of the Hook interface is called.
type ChannelEntryWriter struct {
test.Hook
c chan *logrus.Entry
}
// Fire delegates to the embedded test.Hook Fire method after
// sending the log entry to the channel.
func (c *ChannelEntryWriter) Fire(e *logrus.Entry) error {
if c.c != nil {
c.c <- e
}
return c.Hook.Fire(e)
}
func (c *ChannelEntryWriter) Equal(other logrus.Hook) bool {
return c == other
}
var _ logrus.Hook = &ChannelEntryWriter{}
var _ ComparableHook = &ChannelEntryWriter{}
func init() {
for _, level := range logrus.AllLevels {
if level > highestLevel {
highestLevel = level
}
}
}
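A minimal usage sketch for these helpers (it would live in a _test.go file in the same util package, with fmt imported); this mirrors how the events tests capture and drain log entries:

```go
func ExampleChannelEntryWriter() {
	logChan := make(chan *logrus.Entry, 10)
	undo := RegisterHookWithUndo(logrus.StandardLogger(), NewChannelEntryWriter(logChan))
	defer undo() // removes the hook and restores the previous log level

	logrus.Info("hello")
	fmt.Println((<-logChan).Message)
	// Output: hello
}
```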

View File: testing/util/logging_test.go

@@ -0,0 +1,78 @@
package util
import (
"testing"
"time"
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/v5/testing/require"
"github.com/sirupsen/logrus"
)
func TestUnregister(t *testing.T) {
logger := logrus.New()
logger.SetLevel(logrus.PanicLevel) // set to lowest log level to test level override in RegisterHookWithUndo
assertNoHooks(t, logger)
c := make(chan *logrus.Entry, 1)
tl := NewChannelEntryWriter(c)
undo := RegisterHookWithUndo(logger, tl)
assertRegistered(t, logger, tl)
logger.Trace("test")
select {
case <-c:
default:
t.Fatalf("Expected log entry, got none")
}
undo()
assertNoHooks(t, logger)
require.Equal(t, logrus.PanicLevel, logger.Level)
}
var logTestErr = errors.New("test")
func TestChannelEntryWriter(t *testing.T) {
logger := logrus.New()
c := make(chan *logrus.Entry)
tl := NewChannelEntryWriter(c)
logger.AddHook(tl)
msg := "test"
go func() {
logger.WithError(logTestErr).Info(msg)
}()
select {
case e := <-c:
gotErr := e.Data[logrus.ErrorKey]
if gotErr == nil {
t.Fatalf("Expected error in log entry, got nil")
}
ge, ok := gotErr.(error)
require.Equal(t, true, ok, "Expected error in log entry to be of type error, got %T", gotErr)
require.ErrorIs(t, ge, logTestErr)
require.Equal(t, msg, e.Message)
require.Equal(t, logrus.InfoLevel, e.Level)
case <-time.After(10 * time.Millisecond):
t.Fatalf("Timed out waiting for log entry")
}
}
func assertNoHooks(t *testing.T, logger *logrus.Logger) {
for lvl, hooks := range logger.Hooks {
for _, hook := range hooks {
t.Fatalf("Expected no hooks, got %v at level %s", hook, lvl.String())
}
}
}
func assertRegistered(t *testing.T, logger *logrus.Logger, hook ComparableHook) {
for _, lvl := range hook.Levels() {
registered := logger.Hooks[lvl]
found := false
for _, h := range registered {
if hook.Equal(h) {
found = true
break
}
}
require.Equal(t, true, found, "Expected hook %v to be registered at level %s, but it was not", hook, lvl.String())
}
}