diff --git a/validator/client/runner_test.go b/validator/client/runner_test.go index b617cb033f..3b4aedbb8d 100644 --- a/validator/client/runner_test.go +++ b/validator/client/runner_test.go @@ -181,7 +181,8 @@ func TestAttests_NextSlot(t *testing.T) { node.EXPECT().IsHealthy(gomock.Any()).Return(true).AnyTimes() // avoid race condition between the cancellation of the context in the go stream from slot and the setting of IsHealthy _ = tracker.CheckHealth(context.Background()) - v := &testutil.FakeValidator{Km: &mockKeymanager{accountsChangedFeed: &event.Feed{}}, Tracker: tracker} + attSubmitted := make(chan interface{}) + v := &testutil.FakeValidator{Km: &mockKeymanager{accountsChangedFeed: &event.Feed{}}, Tracker: tracker, AttSubmitted: attSubmitted} ctx, cancel := context.WithCancel(context.Background()) slot := primitives.Slot(55) @@ -193,9 +194,8 @@ func TestAttests_NextSlot(t *testing.T) { cancel() }() - timer := time.NewTimer(200 * time.Millisecond) run(ctx, v) - <-timer.C + <-attSubmitted require.Equal(t, true, v.AttestToBlockHeadCalled, "SubmitAttestation(%d) was not called", slot) assert.Equal(t, uint64(slot), v.AttestToBlockHeadArg1, "SubmitAttestation was called with wrong arg") } @@ -208,7 +208,8 @@ func TestProposes_NextSlot(t *testing.T) { node.EXPECT().IsHealthy(gomock.Any()).Return(true).AnyTimes() // avoid race condition between the cancellation of the context in the go stream from slot and the setting of IsHealthy _ = tracker.CheckHealth(context.Background()) - v := &testutil.FakeValidator{Km: &mockKeymanager{accountsChangedFeed: &event.Feed{}}, Tracker: tracker} + blockProposed := make(chan interface{}) + v := &testutil.FakeValidator{Km: &mockKeymanager{accountsChangedFeed: &event.Feed{}}, Tracker: tracker, BlockProposed: blockProposed} ctx, cancel := context.WithCancel(context.Background()) slot := primitives.Slot(55) @@ -220,9 +221,9 @@ func TestProposes_NextSlot(t *testing.T) { cancel() }() - timer := time.NewTimer(200 * time.Millisecond) run(ctx, v) - <-timer.C + <-blockProposed + require.Equal(t, true, v.ProposeBlockCalled, "ProposeBlock(%d) was not called", slot) assert.Equal(t, uint64(slot), v.ProposeBlockArg1, "ProposeBlock was called with wrong arg") } @@ -235,7 +236,9 @@ func TestBothProposesAndAttests_NextSlot(t *testing.T) { node.EXPECT().IsHealthy(gomock.Any()).Return(true).AnyTimes() // avoid race condition between the cancellation of the context in the go stream from slot and the setting of IsHealthy _ = tracker.CheckHealth(context.Background()) - v := &testutil.FakeValidator{Km: &mockKeymanager{accountsChangedFeed: &event.Feed{}}, Tracker: tracker} + blockProposed := make(chan interface{}) + attSubmitted := make(chan interface{}) + v := &testutil.FakeValidator{Km: &mockKeymanager{accountsChangedFeed: &event.Feed{}}, Tracker: tracker, BlockProposed: blockProposed, AttSubmitted: attSubmitted} ctx, cancel := context.WithCancel(context.Background()) slot := primitives.Slot(55) @@ -247,9 +250,9 @@ func TestBothProposesAndAttests_NextSlot(t *testing.T) { cancel() }() - timer := time.NewTimer(200 * time.Millisecond) run(ctx, v) - <-timer.C + <-blockProposed + <-attSubmitted require.Equal(t, true, v.AttestToBlockHeadCalled, "SubmitAttestation(%d) was not called", slot) assert.Equal(t, uint64(slot), v.AttestToBlockHeadArg1, "SubmitAttestation was called with wrong arg") require.Equal(t, true, v.ProposeBlockCalled, "ProposeBlock(%d) was not called", slot) diff --git a/validator/client/testutil/mock_validator.go b/validator/client/testutil/mock_validator.go index 47c22dde89..15c4d17a53 100644 --- a/validator/client/testutil/mock_validator.go +++ b/validator/client/testutil/mock_validator.go @@ -3,6 +3,7 @@ package testutil import ( "bytes" "context" + "errors" "time" api "github.com/prysmaticlabs/prysm/v5/api/client" @@ -60,6 +61,8 @@ type FakeValidator struct { Km keymanager.IKeymanager graffiti string Tracker *beacon.NodeHealthTracker + AttSubmitted chan interface{} + BlockProposed chan interface{} } // Done for mocking. @@ -73,7 +76,7 @@ func (fv *FakeValidator) WaitForKeymanagerInitialization(_ context.Context) erro return nil } -// LogSyncCommitteeMessagesSubmitted -- +// LogSubmittedSyncCommitteeMessages -- func (fv *FakeValidator) LogSubmittedSyncCommitteeMessages() {} // WaitForChainStart for mocking. @@ -170,12 +173,20 @@ func (fv *FakeValidator) RolesAt(_ context.Context, slot primitives.Slot) (map[[ func (fv *FakeValidator) SubmitAttestation(_ context.Context, slot primitives.Slot, _ [fieldparams.BLSPubkeyLength]byte) { fv.AttestToBlockHeadCalled = true fv.AttestToBlockHeadArg1 = uint64(slot) + if fv.AttSubmitted != nil { + close(fv.AttSubmitted) + fv.AttSubmitted = nil + } } // ProposeBlock for mocking. func (fv *FakeValidator) ProposeBlock(_ context.Context, slot primitives.Slot, _ [fieldparams.BLSPubkeyLength]byte) { fv.ProposeBlockCalled = true fv.ProposeBlockArg1 = uint64(slot) + if fv.BlockProposed != nil { + close(fv.BlockProposed) + fv.BlockProposed = nil + } } // SubmitAggregateAndProof for mocking. @@ -248,9 +259,9 @@ func (fv *FakeValidator) PushProposerSettings(ctx context.Context, km keymanager ctx = nctx defer cancel() time.Sleep(fv.ProposerSettingWait) - if ctx.Err() == context.DeadlineExceeded { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { log.Error("deadline exceeded") - // can't return error or it will trigger a log.fatal + // can't return error as it will trigger a log.fatal return nil } @@ -284,19 +295,19 @@ func (fv *FakeValidator) SetProposerSettings(_ context.Context, settings *propos } // GetGraffiti for mocking -func (f *FakeValidator) GetGraffiti(_ context.Context, _ [fieldparams.BLSPubkeyLength]byte) ([]byte, error) { - return []byte(f.graffiti), nil +func (fv *FakeValidator) GetGraffiti(_ context.Context, _ [fieldparams.BLSPubkeyLength]byte) ([]byte, error) { + return []byte(fv.graffiti), nil } // SetGraffiti for mocking -func (f *FakeValidator) SetGraffiti(_ context.Context, _ [fieldparams.BLSPubkeyLength]byte, graffiti []byte) error { - f.graffiti = string(graffiti) +func (fv *FakeValidator) SetGraffiti(_ context.Context, _ [fieldparams.BLSPubkeyLength]byte, graffiti []byte) error { + fv.graffiti = string(graffiti) return nil } // DeleteGraffiti for mocking -func (f *FakeValidator) DeleteGraffiti(_ context.Context, _ [fieldparams.BLSPubkeyLength]byte) error { - f.graffiti = "" +func (fv *FakeValidator) DeleteGraffiti(_ context.Context, _ [fieldparams.BLSPubkeyLength]byte) error { + fv.graffiti = "" return nil }