diff --git a/beacon-chain/core/epoch/epoch_processing.go b/beacon-chain/core/epoch/epoch_processing.go index f5682244e3..8b501c8ee6 100644 --- a/beacon-chain/core/epoch/epoch_processing.go +++ b/beacon-chain/core/epoch/epoch_processing.go @@ -168,17 +168,9 @@ func ProcessSlashings(state *stateTrie.BeaconState) (*stateTrie.BeaconState, err totalSlashing += slashing } - checker := func(idx int, val *ethpb.Validator) (bool, error) { - correctEpoch := (currentEpoch + exitLength/2) == val.WithdrawableEpoch - if val.Slashed && correctEpoch { - return true, nil - } - return false, nil - } - // a callback is used here to apply the following actions to all validators // below equally. - err = state.ApplyToEveryValidator(checker, func(idx int, val *ethpb.Validator) error { + err = state.ApplyToEveryValidator(func(idx int, val *ethpb.Validator) (bool, error) { correctEpoch := (currentEpoch + exitLength/2) == val.WithdrawableEpoch if val.Slashed && correctEpoch { minSlashing := mathutil.Min(totalSlashing*3, totalBalance) @@ -186,11 +178,11 @@ func ProcessSlashings(state *stateTrie.BeaconState) (*stateTrie.BeaconState, err penaltyNumerator := val.EffectiveBalance / increment * minSlashing penalty := penaltyNumerator / totalBalance * increment if err := helpers.DecreaseBalance(state, uint64(idx), penalty); err != nil { - return err + return false, err } - return nil + return true, nil } - return nil + return false, nil }) return state, err } @@ -246,7 +238,8 @@ func ProcessFinalUpdates(state *stateTrie.BeaconState) (*stateTrie.BeaconState, } bals := state.Balances() - checker := func(idx int, val *ethpb.Validator) (bool, error) { + // Update effective balances with hysteresis. + validatorFunc := func(idx int, val *ethpb.Validator) (bool, error) { if val == nil { return false, fmt.Errorf("validator %d is nil in state", idx) } @@ -261,29 +254,14 @@ func ProcessFinalUpdates(state *stateTrie.BeaconState) (*stateTrie.BeaconState, if balance+downwardThreshold < val.EffectiveBalance || val.EffectiveBalance+upwardThreshold < balance { val.EffectiveBalance = params.BeaconConfig().MaxEffectiveBalance if val.EffectiveBalance > balance-balance%params.BeaconConfig().EffectiveBalanceIncrement { - return true, nil + val.EffectiveBalance = balance - balance%params.BeaconConfig().EffectiveBalanceIncrement } + return true, nil } return false, nil } - // Update effective balances with hysteresis. - updateEffectiveBalances := func(idx int, val *ethpb.Validator) error { - balance := bals[idx] - hysteresisInc := params.BeaconConfig().EffectiveBalanceIncrement / params.BeaconConfig().HysteresisQuotient - downwardThreshold := hysteresisInc * params.BeaconConfig().HysteresisDownwardMultiplier - upwardThreshold := hysteresisInc * params.BeaconConfig().HysteresisUpwardMultiplier - if balance+downwardThreshold < val.EffectiveBalance || val.EffectiveBalance+upwardThreshold < balance { - val.EffectiveBalance = params.BeaconConfig().MaxEffectiveBalance - if val.EffectiveBalance > balance-balance%params.BeaconConfig().EffectiveBalanceIncrement { - val.EffectiveBalance = balance - balance%params.BeaconConfig().EffectiveBalanceIncrement - } - return nil - } - return nil - } - - if err := state.ApplyToEveryValidator(checker, updateEffectiveBalances); err != nil { + if err := state.ApplyToEveryValidator(validatorFunc); err != nil { return nil, err } diff --git a/beacon-chain/core/epoch/precompute/slashing.go b/beacon-chain/core/epoch/precompute/slashing.go index 87332ece3a..6881a131e2 100644 --- a/beacon-chain/core/epoch/precompute/slashing.go +++ b/beacon-chain/core/epoch/precompute/slashing.go @@ -21,15 +21,7 @@ func ProcessSlashingsPrecompute(state *stateTrie.BeaconState, p *Balance) error totalSlashing += slashing } - checker := func(idx int, val *ethpb.Validator) (bool, error) { - correctEpoch := (currentEpoch + exitLength/2) == val.WithdrawableEpoch - if val.Slashed && correctEpoch { - return true, nil - } - return false, nil - } - - updateEffectiveBalances := func(idx int, val *ethpb.Validator) error { + validatorFunc := func(idx int, val *ethpb.Validator) (bool, error) { correctEpoch := (currentEpoch + exitLength/2) == val.WithdrawableEpoch if val.Slashed && correctEpoch { minSlashing := mathutil.Min(totalSlashing*3, p.CurrentEpoch) @@ -37,12 +29,12 @@ func ProcessSlashingsPrecompute(state *stateTrie.BeaconState, p *Balance) error penaltyNumerator := val.EffectiveBalance / increment * minSlashing penalty := penaltyNumerator / p.CurrentEpoch * increment if err := helpers.DecreaseBalance(state, uint64(idx), penalty); err != nil { - return err + return false, err } - return nil + return true, nil } - return nil + return false, nil } - return state.ApplyToEveryValidator(checker, updateEffectiveBalances) + return state.ApplyToEveryValidator(validatorFunc) } diff --git a/beacon-chain/state/references_test.go b/beacon-chain/state/references_test.go index 7df1b9302a..3fe2f8974a 100644 --- a/beacon-chain/state/references_test.go +++ b/beacon-chain/state/references_test.go @@ -6,7 +6,6 @@ import ( "runtime/debug" "testing" - eth "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" "github.com/prysmaticlabs/go-bitfield" p2ppb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" "github.com/prysmaticlabs/prysm/shared/bytesutil" @@ -50,111 +49,6 @@ func TestStateReferenceSharing_Finalizer(t *testing.T) { } } -func TestStateReferenceCopy_NoUnexpectedValidatorMutation(t *testing.T) { - resetCfg := featureconfig.InitWithReset(&featureconfig.Flags{EnableStateRefCopy: true}) - defer resetCfg() - - a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{}) - if err != nil { - t.Fatal(err) - } - - assertRefCount(t, a, validators, 1) - - // Add validator before copying state (so that a and b have shared data). - pubKey1, pubKey2 := [48]byte{29}, [48]byte{31} - err = a.AppendValidator(ð.Validator{ - PublicKey: pubKey1[:], - }) - if len(a.state.GetValidators()) != 1 { - t.Error("No validators found") - } - - // Copy, increases reference count. - b := a.Copy() - assertRefCount(t, a, validators, 2) - assertRefCount(t, b, validators, 2) - if len(b.state.GetValidators()) != 1 { - t.Error("No validators found") - } - - hasValidatorWithPubKey := func(state *p2ppb.BeaconState, key [48]byte) bool { - for _, val := range state.GetValidators() { - if reflect.DeepEqual(val.PublicKey, key[:]) { - return true - } - } - return false - } - - err = a.AppendValidator(ð.Validator{ - PublicKey: pubKey2[:], - }) - if err != nil { - t.Fatal(err) - } - - // Copy on write happened, reference counters are reset. - assertRefCount(t, a, validators, 1) - assertRefCount(t, b, validators, 1) - - valsA := a.state.GetValidators() - valsB := b.state.GetValidators() - if len(valsA) != 2 { - t.Errorf("Unexpected number of validators, want: %v, got: %v", 2, len(valsA)) - } - // Both validators are known to a. - if !hasValidatorWithPubKey(a.state, pubKey1) { - t.Errorf("Expected validator not found, want: %v", pubKey1) - } - if !hasValidatorWithPubKey(a.state, pubKey2) { - t.Errorf("Expected validator not found, want: %v", pubKey2) - } - // Only one validator is known to b. - if !hasValidatorWithPubKey(b.state, pubKey1) { - t.Errorf("Expected validator not found, want: %v", pubKey1) - } - if hasValidatorWithPubKey(b.state, pubKey2) { - t.Errorf("Unexpected validator found: %v", pubKey2) - } - if len(valsA) == len(valsB) { - t.Error("Unexpected state mutation") - } - - // Make sure that function applied to all validators in one state, doesn't affect another. - changedBalance := uint64(1) - for i, val := range valsA { - if val.EffectiveBalance == changedBalance { - t.Errorf("Unexpected effective balance, want: %v, got: %v", 0, valsA[i].EffectiveBalance) - } - } - for i, val := range valsB { - if val.EffectiveBalance == changedBalance { - t.Errorf("Unexpected effective balance, want: %v, got: %v", 0, valsB[i].EffectiveBalance) - } - } - // Applied to a, a and b share reference to the first validator, which shouldn't cause issues. - err = a.ApplyToEveryValidator(func(idx int, val *eth.Validator) (b bool, err error) { - return true, nil - }, func(idx int, val *eth.Validator) error { - val.EffectiveBalance = 1 - return nil - }) - if err != nil { - t.Fatal(err) - } - for i, val := range valsA { - if val.EffectiveBalance != changedBalance { - t.Errorf("Unexpected effective balance, want: %v, got: %v", changedBalance, valsA[i].EffectiveBalance) - } - } - for i, val := range valsB { - if val.EffectiveBalance == changedBalance { - t.Errorf("Unexpected mutation of effective balance, want: %v, got: %v", 0, valsB[i].EffectiveBalance) - } - } -} - func TestStateReferenceCopy_NoUnexpectedRootsMutation(t *testing.T) { resetCfg := featureconfig.InitWithReset(&featureconfig.Flags{EnableStateRefCopy: true}) defer resetCfg() diff --git a/beacon-chain/state/setters.go b/beacon-chain/state/setters.go index b7ab870eff..d730b16dc2 100644 --- a/beacon-chain/state/setters.go +++ b/beacon-chain/state/setters.go @@ -299,8 +299,7 @@ func (b *BeaconState) SetValidators(val []*ethpb.Validator) error { // ApplyToEveryValidator applies the provided callback function to each validator in the // validator registry. -func (b *BeaconState) ApplyToEveryValidator(checker func(idx int, val *ethpb.Validator) (bool, error), - mutator func(idx int, val *ethpb.Validator) error) error { +func (b *BeaconState) ApplyToEveryValidator(f func(idx int, val *ethpb.Validator) (bool, error)) error { if !b.HasInnerState() { return ErrNilInnerState } @@ -308,35 +307,21 @@ func (b *BeaconState) ApplyToEveryValidator(checker func(idx int, val *ethpb.Val v := b.state.Validators if ref := b.sharedFieldReferences[validators]; ref.refs > 1 { // Perform a copy since this is a shared reference and we don't want to mutate others. - if featureconfig.Get().EnableStateRefCopy { - v = make([]*ethpb.Validator, len(b.state.Validators)) - copy(v, b.state.Validators) - } else { - v = b.Validators() - } + v = b.Validators() + ref.MinusRef() b.sharedFieldReferences[validators] = &reference{refs: 1} } b.lock.RUnlock() - var changedVals []uint64 + changedVals := []uint64{} for i, val := range v { - changed, err := checker(i, val) + changed, err := f(i, val) if err != nil { return err } - if !changed { - continue + if changed { + changedVals = append(changedVals, uint64(i)) } - if featureconfig.Get().EnableStateRefCopy { - // copy if changing a reference - val = CopyValidator(val) - } - err = mutator(i, val) - if err != nil { - return err - } - changedVals = append(changedVals, uint64(i)) - v[i] = val } b.lock.Lock() @@ -363,12 +348,7 @@ func (b *BeaconState) UpdateValidatorAtIndex(idx uint64, val *ethpb.Validator) e v := b.state.Validators if ref := b.sharedFieldReferences[validators]; ref.refs > 1 { // Perform a copy since this is a shared reference and we don't want to mutate others. - if featureconfig.Get().EnableStateRefCopy { - v = make([]*ethpb.Validator, len(b.state.Validators)) - copy(v, b.state.Validators) - } else { - v = b.Validators() - } + v = b.Validators() ref.MinusRef() b.sharedFieldReferences[validators] = &reference{refs: 1} @@ -677,12 +657,7 @@ func (b *BeaconState) AppendValidator(val *ethpb.Validator) error { b.lock.RLock() vals := b.state.Validators if b.sharedFieldReferences[validators].refs > 1 { - if featureconfig.Get().EnableStateRefCopy { - vals = make([]*ethpb.Validator, len(b.state.Validators)) - copy(vals, b.state.Validators) - } else { - vals = b.Validators() - } + vals = b.Validators() b.sharedFieldReferences[validators].MinusRef() b.sharedFieldReferences[validators] = &reference{refs: 1} }