From 143cb142bcda7af8e8621feedea2cf4b815123ae Mon Sep 17 00:00:00 2001 From: Nishant Das Date: Thu, 11 Feb 2021 04:52:45 +0800 Subject: [PATCH] Make Individual Validators Immutable (#8397) * initial POC * clean up Co-authored-by: Raul Jordan Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com> --- beacon-chain/core/epoch/epoch_processing.go | 25 ++++++++------- .../core/epoch/precompute/slashing.go | 8 ++--- beacon-chain/state/getters.go | 25 ++++++++++++++- beacon-chain/state/references_test.go | 32 +++++++++++++++++++ beacon-chain/state/setters.go | 15 ++++----- 5 files changed, 79 insertions(+), 26 deletions(-) diff --git a/beacon-chain/core/epoch/epoch_processing.go b/beacon-chain/core/epoch/epoch_processing.go index 5cba16040b..c450491769 100644 --- a/beacon-chain/core/epoch/epoch_processing.go +++ b/beacon-chain/core/epoch/epoch_processing.go @@ -175,17 +175,17 @@ func ProcessSlashings(state *stateTrie.BeaconState) (*stateTrie.BeaconState, err // below equally. increment := params.BeaconConfig().EffectiveBalanceIncrement minSlashing := mathutil.Min(totalSlashing*params.BeaconConfig().ProportionalSlashingMultiplier, totalBalance) - err = state.ApplyToEveryValidator(func(idx int, val *ethpb.Validator) (bool, error) { + err = state.ApplyToEveryValidator(func(idx int, val *ethpb.Validator) (bool, *ethpb.Validator, error) { correctEpoch := (currentEpoch + exitLength/2) == val.WithdrawableEpoch if val.Slashed && correctEpoch { penaltyNumerator := val.EffectiveBalance / increment * minSlashing penalty := penaltyNumerator / totalBalance * increment if err := helpers.DecreaseBalance(state, uint64(idx), penalty); err != nil { - return false, err + return false, val, err } - return true, nil + return true, val, nil } - return false, nil + return false, val, nil }) return state, err } @@ -249,23 +249,24 @@ func ProcessFinalUpdates(state *stateTrie.BeaconState) (*stateTrie.BeaconState, bals := state.Balances() // Update effective balances with hysteresis. - validatorFunc := func(idx int, val *ethpb.Validator) (bool, error) { + validatorFunc := func(idx int, val *ethpb.Validator) (bool, *ethpb.Validator, error) { if val == nil { - return false, fmt.Errorf("validator %d is nil in state", idx) + return false, nil, fmt.Errorf("validator %d is nil in state", idx) } if idx >= len(bals) { - return false, fmt.Errorf("validator index exceeds validator length in state %d >= %d", idx, len(state.Balances())) + return false, nil, fmt.Errorf("validator index exceeds validator length in state %d >= %d", idx, len(state.Balances())) } balance := bals[idx] if balance+downwardThreshold < val.EffectiveBalance || val.EffectiveBalance+upwardThreshold < balance { - val.EffectiveBalance = maxEffBalance - if val.EffectiveBalance > balance-balance%effBalanceInc { - val.EffectiveBalance = balance - balance%effBalanceInc + newVal := stateTrie.CopyValidator(val) + newVal.EffectiveBalance = maxEffBalance + if newVal.EffectiveBalance > balance-balance%effBalanceInc { + newVal.EffectiveBalance = balance - balance%effBalanceInc } - return true, nil + return true, newVal, nil } - return false, nil + return false, val, nil } if err := state.ApplyToEveryValidator(validatorFunc); err != nil { diff --git a/beacon-chain/core/epoch/precompute/slashing.go b/beacon-chain/core/epoch/precompute/slashing.go index 425c4a2bc6..9c60f3ff8c 100644 --- a/beacon-chain/core/epoch/precompute/slashing.go +++ b/beacon-chain/core/epoch/precompute/slashing.go @@ -42,17 +42,17 @@ func ProcessSlashingsPrecompute(state *stateTrie.BeaconState, pBal *Balance) err } increment := params.BeaconConfig().EffectiveBalanceIncrement - validatorFunc := func(idx int, val *ethpb.Validator) (bool, error) { + validatorFunc := func(idx int, val *ethpb.Validator) (bool, *ethpb.Validator, error) { correctEpoch := epochToWithdraw == val.WithdrawableEpoch if val.Slashed && correctEpoch { penaltyNumerator := val.EffectiveBalance / increment * minSlashing penalty := penaltyNumerator / pBal.ActiveCurrentEpoch * increment if err := helpers.DecreaseBalance(state, uint64(idx), penalty); err != nil { - return false, err + return false, val, err } - return true, nil + return true, val, nil } - return false, nil + return false, val, nil } return state.ApplyToEveryValidator(validatorFunc) diff --git a/beacon-chain/state/getters.go b/beacon-chain/state/getters.go index 5cfbc34bb8..c8469cf431 100644 --- a/beacon-chain/state/getters.go +++ b/beacon-chain/state/getters.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "github.com/prysmaticlabs/eth2-types" + types "github.com/prysmaticlabs/eth2-types" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" "github.com/prysmaticlabs/go-bitfield" pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" @@ -592,6 +592,29 @@ func (b *BeaconState) validators() []*ethpb.Validator { return res } +// references of validators participating in consensus on the beacon chain. +// This assumes that a lock is already held on BeaconState. This does not +// copy fully and instead just copies the reference. +func (b *BeaconState) validatorsReferences() []*ethpb.Validator { + if !b.HasInnerState() { + return nil + } + if b.state.Validators == nil { + return nil + } + + res := make([]*ethpb.Validator, len(b.state.Validators)) + for i := 0; i < len(res); i++ { + validator := b.state.Validators[i] + if validator == nil { + continue + } + // copy validator reference instead. + res[i] = validator + } + return res +} + // ValidatorAtIndex is the validator at the provided index. func (b *BeaconState) ValidatorAtIndex(idx uint64) (*ethpb.Validator, error) { if !b.HasInnerState() { diff --git a/beacon-chain/state/references_test.go b/beacon-chain/state/references_test.go index 5c5ee7307a..0a2fa48e11 100644 --- a/beacon-chain/state/references_test.go +++ b/beacon-chain/state/references_test.go @@ -6,6 +6,7 @@ import ( "runtime/debug" "testing" + ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" "github.com/prysmaticlabs/go-bitfield" p2ppb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" "github.com/prysmaticlabs/prysm/shared/bytesutil" @@ -272,6 +273,37 @@ func TestStateReferenceCopy_NoUnexpectedAttestationsMutation(t *testing.T) { assertRefCount(t, b, previousEpochAttestations, 1) } +func TestValidatorReferences_RemainsConsistent(t *testing.T) { + a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{ + Validators: []*ethpb.Validator{ + {PublicKey: []byte{'A'}}, + {PublicKey: []byte{'B'}}, + {PublicKey: []byte{'C'}}, + {PublicKey: []byte{'D'}}, + {PublicKey: []byte{'E'}}, + }, + }) + require.NoError(t, err) + + // Create a second state. + b := a.Copy() + + // Update First Validator. + assert.NoError(t, a.UpdateValidatorAtIndex(0, ðpb.Validator{PublicKey: []byte{'Z'}})) + + assert.DeepNotEqual(t, a.state.Validators[0], b.state.Validators[0], "validators are equal when they are supposed to be different") + // Modify all validators from copied state. + assert.NoError(t, b.ApplyToEveryValidator(func(idx int, val *ethpb.Validator) (bool, *ethpb.Validator, error) { + return true, ðpb.Validator{PublicKey: []byte{'V'}}, nil + })) + + // Ensure reference is properly accounted for. + assert.NoError(t, a.ReadFromEveryValidator(func(idx int, val ReadOnlyValidator) error { + assert.NotEqual(t, bytesutil.ToBytes48([]byte{'V'}), val.PublicKey()) + return nil + })) +} + // assertRefCount checks whether reference count for a given state // at a given index is equal to expected amount. func assertRefCount(t *testing.T, b *BeaconState, idx fieldIndex, want uint) { diff --git a/beacon-chain/state/setters.go b/beacon-chain/state/setters.go index 28a74a3cc8..8c4234c6a8 100644 --- a/beacon-chain/state/setters.go +++ b/beacon-chain/state/setters.go @@ -304,28 +304,27 @@ func (b *BeaconState) SetValidators(val []*ethpb.Validator) error { // ApplyToEveryValidator applies the provided callback function to each validator in the // validator registry. -func (b *BeaconState) ApplyToEveryValidator(f func(idx int, val *ethpb.Validator) (bool, error)) error { +func (b *BeaconState) ApplyToEveryValidator(f func(idx int, val *ethpb.Validator) (bool, *ethpb.Validator, error)) error { if !b.HasInnerState() { return ErrNilInnerState } b.lock.Lock() v := b.state.Validators if ref := b.sharedFieldReferences[validators]; ref.Refs() > 1 { - // Perform a copy since this is a shared reference and we don't want to mutate others. - v = b.validators() - + v = b.validatorsReferences() ref.MinusRef() b.sharedFieldReferences[validators] = &reference{refs: 1} } b.lock.Unlock() var changedVals []uint64 for i, val := range v { - changed, err := f(i, val) + changed, newVal, err := f(i, val) if err != nil { return err } if changed { changedVals = append(changedVals, uint64(i)) + v[i] = newVal } } @@ -353,9 +352,7 @@ func (b *BeaconState) UpdateValidatorAtIndex(idx uint64, val *ethpb.Validator) e v := b.state.Validators if ref := b.sharedFieldReferences[validators]; ref.Refs() > 1 { - // Perform a copy since this is a shared reference and we don't want to mutate others. - v = b.validators() - + v = b.validatorsReferences() ref.MinusRef() b.sharedFieldReferences[validators] = &reference{refs: 1} } @@ -632,7 +629,7 @@ func (b *BeaconState) AppendValidator(val *ethpb.Validator) error { vals := b.state.Validators if b.sharedFieldReferences[validators].Refs() > 1 { - vals = b.validators() + vals = b.validatorsReferences() b.sharedFieldReferences[validators].MinusRef() b.sharedFieldReferences[validators] = &reference{refs: 1} }