Add Reference Copying (#5245)

* add state changes
* add new changes
* add flag
* lint
* add flag
* change to correct bool
* fixing and consolidating trie
* lint
* Apply suggestions from code review
* Merge refs/heads/master into addBetterCopying
* Merge branch 'master' into addBetterCopying
* refCopy -> stateRefCopy
* Merge refs/heads/master into addBetterCopying
* tests whether unexpected mutation of validators within state is avoided
* Merge branch 'addBetterCopying' of github.com:prysmaticlabs/prysm into addBetterCopying
* remove unnecessary fields
* gazelle
* updates test
* avoid unexpected mutation in block roots on refcopy
* avoid unexpected mutation in state roots on refcopy
* Merge refs/heads/master into addBetterCopying
* Merge branch 'master' into addBetterCopying
* fix test
* randao tests
* simplify tests
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* test cur/prev attestations mutation
* Merge branch 'addBetterCopying' of github.com:prysmaticlabs/prysm into addBetterCopying
* gazelle
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* fixes tests
* minor naming update
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
* Merge refs/heads/master into addBetterCopying
This commit is contained in:
Nishant Das
2020-04-19 21:57:43 +08:00
committed by GitHub
parent dee3f02e2c
commit 626b3e0c66
12 changed files with 678 additions and 96 deletions

View File

@@ -168,9 +168,17 @@ func ProcessSlashings(state *stateTrie.BeaconState) (*stateTrie.BeaconState, err
totalSlashing += slashing
}
checker := func(idx int, val *ethpb.Validator) (bool, error) {
correctEpoch := (currentEpoch + exitLength/2) == val.WithdrawableEpoch
if val.Slashed && correctEpoch {
return true, nil
}
return false, nil
}
// a callback is used here to apply the following actions to all validators
// below equally.
err = state.ApplyToEveryValidator(func(idx int, val *ethpb.Validator) (bool, error) {
err = state.ApplyToEveryValidator(checker, func(idx int, val *ethpb.Validator) error {
correctEpoch := (currentEpoch + exitLength/2) == val.WithdrawableEpoch
if val.Slashed && correctEpoch {
minSlashing := mathutil.Min(totalSlashing*3, totalBalance)
@@ -178,11 +186,11 @@ func ProcessSlashings(state *stateTrie.BeaconState) (*stateTrie.BeaconState, err
penaltyNumerator := val.EffectiveBalance / increment * minSlashing
penalty := penaltyNumerator / totalBalance * increment
if err := helpers.DecreaseBalance(state, uint64(idx), penalty); err != nil {
return false, err
return err
}
return true, nil
return nil
}
return false, nil
return nil
})
return state, err
}
@@ -238,8 +246,7 @@ func ProcessFinalUpdates(state *stateTrie.BeaconState) (*stateTrie.BeaconState,
}
bals := state.Balances()
// Update effective balances with hysteresis.
validatorFunc := func(idx int, val *ethpb.Validator) (bool, error) {
checker := func(idx int, val *ethpb.Validator) (bool, error) {
if val == nil {
return false, fmt.Errorf("validator %d is nil in state", idx)
}
@@ -254,14 +261,29 @@ func ProcessFinalUpdates(state *stateTrie.BeaconState) (*stateTrie.BeaconState,
if balance+downwardThreshold < val.EffectiveBalance || val.EffectiveBalance+upwardThreshold < balance {
val.EffectiveBalance = params.BeaconConfig().MaxEffectiveBalance
if val.EffectiveBalance > balance-balance%params.BeaconConfig().EffectiveBalanceIncrement {
val.EffectiveBalance = balance - balance%params.BeaconConfig().EffectiveBalanceIncrement
return true, nil
}
return true, nil
}
return false, nil
}
// Update effective balances with hysteresis.
updateEffectiveBalances := func(idx int, val *ethpb.Validator) error {
balance := bals[idx]
hysteresisInc := params.BeaconConfig().EffectiveBalanceIncrement / params.BeaconConfig().HysteresisQuotient
downwardThreshold := hysteresisInc * params.BeaconConfig().HysteresisDownwardMultiplier
upwardThreshold := hysteresisInc * params.BeaconConfig().HysteresisUpwardMultiplier
if err := state.ApplyToEveryValidator(validatorFunc); err != nil {
if balance+downwardThreshold < val.EffectiveBalance || val.EffectiveBalance+upwardThreshold < balance {
val.EffectiveBalance = params.BeaconConfig().MaxEffectiveBalance
if val.EffectiveBalance > balance-balance%params.BeaconConfig().EffectiveBalanceIncrement {
val.EffectiveBalance = balance - balance%params.BeaconConfig().EffectiveBalanceIncrement
}
return nil
}
return nil
}
if err := state.ApplyToEveryValidator(checker, updateEffectiveBalances); err != nil {
return nil, err
}

View File

@@ -21,7 +21,15 @@ func ProcessSlashingsPrecompute(state *stateTrie.BeaconState, p *Balance) error
totalSlashing += slashing
}
validatorFunc := func(idx int, val *ethpb.Validator) (bool, error) {
checker := func(idx int, val *ethpb.Validator) (bool, error) {
correctEpoch := (currentEpoch + exitLength/2) == val.WithdrawableEpoch
if val.Slashed && correctEpoch {
return true, nil
}
return false, nil
}
updateEffectiveBalances := func(idx int, val *ethpb.Validator) error {
correctEpoch := (currentEpoch + exitLength/2) == val.WithdrawableEpoch
if val.Slashed && correctEpoch {
minSlashing := mathutil.Min(totalSlashing*3, p.CurrentEpoch)
@@ -29,12 +37,12 @@ func ProcessSlashingsPrecompute(state *stateTrie.BeaconState, p *Balance) error
penaltyNumerator := val.EffectiveBalance / increment * minSlashing
penalty := penaltyNumerator / p.CurrentEpoch * increment
if err := helpers.DecreaseBalance(state, uint64(idx), penalty); err != nil {
return false, err
return err
}
return true, nil
return nil
}
return false, nil
return nil
}
return state.ApplyToEveryValidator(validatorFunc)
return state.ApplyToEveryValidator(checker, updateEffectiveBalances)
}

View File

@@ -48,11 +48,13 @@ go_test(
"//beacon-chain/state/stateutil:go_default_library",
"//proto/beacon/p2p/v1:go_default_library",
"//shared/bytesutil:go_default_library",
"//shared/featureconfig:go_default_library",
"//shared/interop:go_default_library",
"//shared/params:go_default_library",
"//shared/testutil:go_default_library",
"@com_github_gogo_protobuf//proto:go_default_library",
"@com_github_prysmaticlabs_ethereumapis//eth/v1alpha1:go_default_library",
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
"@com_github_prysmaticlabs_go_ssz//:go_default_library",
"@com_github_sirupsen_logrus//:go_default_library",
],

View File

@@ -266,3 +266,21 @@ func CopySignedVoluntaryExit(exit *ethpb.SignedVoluntaryExit) *ethpb.SignedVolun
Signature: bytesutil.SafeCopyBytes(exit.Signature),
}
}
// CopyValidator copies the provided validator.
func CopyValidator(val *ethpb.Validator) *ethpb.Validator {
pubKey := make([]byte, len(val.PublicKey))
copy(pubKey, val.PublicKey)
withdrawalCreds := make([]byte, len(val.WithdrawalCredentials))
copy(withdrawalCreds, val.WithdrawalCredentials)
return &ethpb.Validator{
PublicKey: pubKey[:],
WithdrawalCredentials: withdrawalCreds,
EffectiveBalance: val.EffectiveBalance,
Slashed: val.Slashed,
ActivationEligibilityEpoch: val.ActivationEligibilityEpoch,
ActivationEpoch: val.ActivationEpoch,
ExitEpoch: val.ExitEpoch,
WithdrawableEpoch: val.WithdrawableEpoch,
}
}

View File

@@ -113,10 +113,10 @@ func (f *FieldTrie) CopyTrie() *FieldTrie {
switch f.field {
case randaoMixes:
dstFieldTrie = memorypool.GetRandaoMixesTrie(len(f.fieldLayers))
case blockRoots:
dstFieldTrie = memorypool.GetBlockRootsTrie(len(f.fieldLayers))
case stateRoots:
dstFieldTrie = memorypool.GetStateRootsTrie(len(f.fieldLayers))
case blockRoots, stateRoots:
dstFieldTrie = memorypool.GetRootsTrie(len(f.fieldLayers))
case validators:
dstFieldTrie = memorypool.GetValidatorsTrie(len(f.fieldLayers))
default:
dstFieldTrie = make([][]*[32]byte, len(f.fieldLayers))
}

View File

@@ -79,6 +79,11 @@ func (v *ReadOnlyValidator) Slashed() bool {
return v.validator.Slashed
}
// CopyValidator returns the copy of the read only validator.
func (v *ReadOnlyValidator) CopyValidator() *ethpb.Validator {
return CopyValidator(v.validator)
}
// InnerStateUnsafe returns the pointer value of the underlying
// beacon state proto object, bypassing immutability. Use with care.
func (b *BeaconState) InnerStateUnsafe() *pbp2p.BeaconState {
@@ -366,20 +371,7 @@ func (b *BeaconState) Validators() []*ethpb.Validator {
if val == nil {
continue
}
pubKey := make([]byte, len(val.PublicKey))
copy(pubKey, val.PublicKey)
withdrawalCreds := make([]byte, len(val.WithdrawalCredentials))
copy(withdrawalCreds, val.WithdrawalCredentials)
res[i] = &ethpb.Validator{
PublicKey: pubKey[:],
WithdrawalCredentials: withdrawalCreds,
EffectiveBalance: val.EffectiveBalance,
Slashed: val.Slashed,
ActivationEligibilityEpoch: val.ActivationEligibilityEpoch,
ActivationEpoch: val.ActivationEpoch,
ExitEpoch: val.ExitEpoch,
WithdrawableEpoch: val.WithdrawableEpoch,
}
res[i] = CopyValidator(val)
}
return res
}

View File

@@ -1,10 +1,16 @@
package state
import (
"reflect"
"runtime"
"runtime/debug"
"testing"
eth "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1"
"github.com/prysmaticlabs/go-bitfield"
p2ppb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
"github.com/prysmaticlabs/prysm/shared/bytesutil"
"github.com/prysmaticlabs/prysm/shared/featureconfig"
)
func TestStateReferenceSharing_Finalizer(t *testing.T) {
@@ -43,3 +49,465 @@ func TestStateReferenceSharing_Finalizer(t *testing.T) {
t.Error("Expected 1 shared reference to randao mix for both a and b")
}
}
func TestStateReferenceCopy_NoUnexpectedValidatorMutation(t *testing.T) {
// Assert that feature is enabled.
if cfg := featureconfig.Get(); !cfg.EnableStateRefCopy {
cfg.EnableStateRefCopy = true
featureconfig.Init(cfg)
defer func() {
cfg := featureconfig.Get()
cfg.EnableStateRefCopy = false
featureconfig.Init(cfg)
}()
}
a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{})
if err != nil {
t.Fatal(err)
}
assertRefCount(t, a, validators, 1)
// Add validator before copying state (so that a and b have shared data).
pubKey1, pubKey2 := [48]byte{29}, [48]byte{31}
err = a.AppendValidator(&eth.Validator{
PublicKey: pubKey1[:],
})
if len(a.state.GetValidators()) != 1 {
t.Error("No validators found")
}
// Copy, increases reference count.
b := a.Copy()
assertRefCount(t, a, validators, 2)
assertRefCount(t, b, validators, 2)
if len(b.state.GetValidators()) != 1 {
t.Error("No validators found")
}
hasValidatorWithPubKey := func(state *p2ppb.BeaconState, key [48]byte) bool {
for _, val := range state.GetValidators() {
if reflect.DeepEqual(val.PublicKey, key[:]) {
return true
}
}
return false
}
err = a.AppendValidator(&eth.Validator{
PublicKey: pubKey2[:],
})
if err != nil {
t.Fatal(err)
}
// Copy on write happened, reference counters are reset.
assertRefCount(t, a, validators, 1)
assertRefCount(t, b, validators, 1)
valsA := a.state.GetValidators()
valsB := b.state.GetValidators()
if len(valsA) != 2 {
t.Errorf("Unexpected number of validators, want: %v, got: %v", 2, len(valsA))
}
// Both validators are known to a.
if !hasValidatorWithPubKey(a.state, pubKey1) {
t.Errorf("Expected validator not found, want: %v", pubKey1)
}
if !hasValidatorWithPubKey(a.state, pubKey2) {
t.Errorf("Expected validator not found, want: %v", pubKey2)
}
// Only one validator is known to b.
if !hasValidatorWithPubKey(b.state, pubKey1) {
t.Errorf("Expected validator not found, want: %v", pubKey1)
}
if hasValidatorWithPubKey(b.state, pubKey2) {
t.Errorf("Unexpected validator found: %v", pubKey2)
}
if len(valsA) == len(valsB) {
t.Error("Unexpected state mutation")
}
// Make sure that function applied to all validators in one state, doesn't affect another.
changedBalance := uint64(1)
for i, val := range valsA {
if val.EffectiveBalance == changedBalance {
t.Errorf("Unexpected effective balance, want: %v, got: %v", 0, valsA[i].EffectiveBalance)
}
}
for i, val := range valsB {
if val.EffectiveBalance == changedBalance {
t.Errorf("Unexpected effective balance, want: %v, got: %v", 0, valsB[i].EffectiveBalance)
}
}
// Applied to a, a and b share reference to the first validator, which shouldn't cause issues.
err = a.ApplyToEveryValidator(func(idx int, val *eth.Validator) (b bool, err error) {
return true, nil
}, func(idx int, val *eth.Validator) error {
val.EffectiveBalance = 1
return nil
})
if err != nil {
t.Fatal(err)
}
for i, val := range valsA {
if val.EffectiveBalance != changedBalance {
t.Errorf("Unexpected effective balance, want: %v, got: %v", changedBalance, valsA[i].EffectiveBalance)
}
}
for i, val := range valsB {
if val.EffectiveBalance == changedBalance {
t.Errorf("Unexpected mutation of effective balance, want: %v, got: %v", 0, valsB[i].EffectiveBalance)
}
}
}
func TestStateReferenceCopy_NoUnexpectedRootsMutation(t *testing.T) {
// Assert that feature is enabled.
if cfg := featureconfig.Get(); !cfg.EnableStateRefCopy {
cfg.EnableStateRefCopy = true
featureconfig.Init(cfg)
defer func() {
cfg := featureconfig.Get()
cfg.EnableStateRefCopy = false
featureconfig.Init(cfg)
}()
}
root1, root2 := bytesutil.ToBytes32([]byte("foo")), bytesutil.ToBytes32([]byte("bar"))
a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{
BlockRoots: [][]byte{
root1[:],
},
StateRoots: [][]byte{
root1[:],
},
})
if err != nil {
t.Fatal(err)
}
assertRefCount(t, a, blockRoots, 1)
assertRefCount(t, a, stateRoots, 1)
// Copy, increases reference count.
b := a.Copy()
assertRefCount(t, a, blockRoots, 2)
assertRefCount(t, a, stateRoots, 2)
assertRefCount(t, b, blockRoots, 2)
assertRefCount(t, b, stateRoots, 2)
if len(b.state.GetBlockRoots()) != 1 {
t.Error("No block roots found")
}
if len(b.state.GetStateRoots()) != 1 {
t.Error("No state roots found")
}
// Assert shared state.
blockRootsA := a.state.GetBlockRoots()
stateRootsA := a.state.GetStateRoots()
blockRootsB := b.state.GetBlockRoots()
stateRootsB := b.state.GetStateRoots()
if len(blockRootsA) != len(blockRootsB) || len(blockRootsA) < 1 {
t.Errorf("Unexpected number of block roots, want: %v", 1)
}
if len(stateRootsA) != len(stateRootsB) || len(stateRootsA) < 1 {
t.Errorf("Unexpected number of state roots, want: %v", 1)
}
assertValFound(t, blockRootsA, root1[:])
assertValFound(t, blockRootsB, root1[:])
assertValFound(t, stateRootsA, root1[:])
assertValFound(t, stateRootsB, root1[:])
// Mutator should only affect calling state: a.
err = a.UpdateBlockRootAtIndex(0, root2)
if err != nil {
t.Fatal(err)
}
err = a.UpdateStateRootAtIndex(0, root2)
if err != nil {
t.Fatal(err)
}
// Assert no shared state mutation occurred only on state a (copy on write).
assertValNotFound(t, a.state.GetBlockRoots(), root1[:])
assertValNotFound(t, a.state.GetStateRoots(), root1[:])
assertValFound(t, a.state.GetBlockRoots(), root2[:])
assertValFound(t, a.state.GetStateRoots(), root2[:])
assertValFound(t, b.state.GetBlockRoots(), root1[:])
assertValFound(t, b.state.GetStateRoots(), root1[:])
if len(blockRootsA) != len(blockRootsB) || len(blockRootsA) < 1 {
t.Errorf("Unexpected number of block roots, want: %v", 1)
}
if len(stateRootsA) != len(stateRootsB) || len(stateRootsA) < 1 {
t.Errorf("Unexpected number of state roots, want: %v", 1)
}
if !reflect.DeepEqual(a.state.GetBlockRoots()[0], root2[:]) {
t.Errorf("Expected mutation not found")
}
if !reflect.DeepEqual(a.state.GetStateRoots()[0], root2[:]) {
t.Errorf("Expected mutation not found")
}
if !reflect.DeepEqual(blockRootsB[0], root1[:]) {
t.Errorf("Unexpected mutation found")
}
if !reflect.DeepEqual(stateRootsB[0], root1[:]) {
t.Errorf("Unexpected mutation found")
}
// Copy on write happened, reference counters are reset.
assertRefCount(t, a, blockRoots, 1)
assertRefCount(t, a, stateRoots, 1)
assertRefCount(t, b, blockRoots, 1)
assertRefCount(t, b, stateRoots, 1)
}
func TestStateReferenceCopy_NoUnexpectedRandaoMutation(t *testing.T) {
// Assert that feature is enabled.
if cfg := featureconfig.Get(); !cfg.EnableStateRefCopy {
cfg.EnableStateRefCopy = true
featureconfig.Init(cfg)
defer func() {
cfg := featureconfig.Get()
cfg.EnableStateRefCopy = false
featureconfig.Init(cfg)
}()
}
val1, val2 := []byte("foo"), []byte("bar")
a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{
RandaoMixes: [][]byte{
val1,
},
})
if err != nil {
t.Fatal(err)
}
assertRefCount(t, a, randaoMixes, 1)
// Copy, increases reference count.
b := a.Copy()
assertRefCount(t, a, randaoMixes, 2)
assertRefCount(t, b, randaoMixes, 2)
if len(b.state.GetRandaoMixes()) != 1 {
t.Error("No randao mixes found")
}
// Assert shared state.
mixesA := a.state.GetRandaoMixes()
mixesB := b.state.GetRandaoMixes()
if len(mixesA) != len(mixesB) || len(mixesA) < 1 {
t.Errorf("Unexpected number of mix values, want: %v", 1)
}
assertValFound(t, mixesA, val1)
assertValFound(t, mixesB, val1)
// Mutator should only affect calling state: a.
err = a.UpdateRandaoMixesAtIndex(0, val2)
if err != nil {
t.Fatal(err)
}
// Assert no shared state mutation occurred only on state a (copy on write).
if len(mixesA) != len(mixesB) || len(mixesA) < 1 {
t.Errorf("Unexpected number of mix values, want: %v", 1)
}
assertValFound(t, a.state.GetRandaoMixes(), val2)
assertValNotFound(t, a.state.GetRandaoMixes(), val1)
assertValFound(t, b.state.GetRandaoMixes(), val1)
assertValNotFound(t, b.state.GetRandaoMixes(), val2)
assertValFound(t, mixesB, val1)
assertValNotFound(t, mixesB, val2)
if !reflect.DeepEqual(a.state.GetRandaoMixes()[0], val2) {
t.Errorf("Expected mutation not found")
}
if !reflect.DeepEqual(mixesB[0], val1) {
t.Errorf("Unexpected mutation found")
}
// Copy on write happened, reference counters are reset.
assertRefCount(t, a, randaoMixes, 1)
assertRefCount(t, b, randaoMixes, 1)
}
func TestStateReferenceCopy_NoUnexpectedAttestationsMutation(t *testing.T) {
// Assert that feature is enabled.
if cfg := featureconfig.Get(); !cfg.EnableStateRefCopy {
cfg.EnableStateRefCopy = true
featureconfig.Init(cfg)
defer func() {
cfg := featureconfig.Get()
cfg.EnableStateRefCopy = false
featureconfig.Init(cfg)
}()
}
assertAttFound := func(vals []*p2ppb.PendingAttestation, val uint64) {
for i := range vals {
if reflect.DeepEqual(vals[i].AggregationBits, bitfield.NewBitlist(val)) {
return
}
}
t.Log(string(debug.Stack()))
t.Fatalf("Expected attestation not found (%v), want: %v", vals, val)
}
assertAttNotFound := func(vals []*p2ppb.PendingAttestation, val uint64) {
for i := range vals {
if reflect.DeepEqual(vals[i].AggregationBits, bitfield.NewBitlist(val)) {
t.Log(string(debug.Stack()))
t.Fatalf("Unexpected attestation found (%v): %v", vals, val)
return
}
}
}
a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{})
if err != nil {
t.Fatal(err)
}
assertRefCount(t, a, previousEpochAttestations, 1)
assertRefCount(t, a, currentEpochAttestations, 1)
// Update initial state.
atts := []*p2ppb.PendingAttestation{
{AggregationBits: bitfield.NewBitlist(1),},
{AggregationBits: bitfield.NewBitlist(2),},
}
if err := a.SetPreviousEpochAttestations(atts[:1]); err != nil {
t.Fatal(err)
}
if err := a.SetCurrentEpochAttestations(atts[:1]); err != nil {
t.Fatal(err)
}
if len(a.CurrentEpochAttestations()) != 1 {
t.Errorf("Unexpected number of attestations, want: %v", 1)
}
if len(a.PreviousEpochAttestations()) != 1 {
t.Errorf("Unexpected number of attestations, want: %v", 1)
}
// Copy, increases reference count.
b := a.Copy()
assertRefCount(t, a, previousEpochAttestations, 2)
assertRefCount(t, a, currentEpochAttestations, 2)
assertRefCount(t, b, previousEpochAttestations, 2)
assertRefCount(t, b, currentEpochAttestations, 2)
if len(b.state.GetPreviousEpochAttestations()) != 1 {
t.Errorf("Unexpected number of attestations, want: %v", 1)
}
if len(b.state.GetCurrentEpochAttestations()) != 1 {
t.Errorf("Unexpected number of attestations, want: %v", 1)
}
// Assert shared state.
curAttsA := a.state.GetCurrentEpochAttestations()
prevAttsA := a.state.GetPreviousEpochAttestations()
curAttsB := b.state.GetCurrentEpochAttestations()
prevAttsB := b.state.GetPreviousEpochAttestations()
if len(curAttsA) != len(curAttsB) || len(curAttsA) < 1 {
t.Errorf("Unexpected number of attestations, want: %v", 1)
}
if len(prevAttsA) != len(prevAttsB) || len(prevAttsA) < 1 {
t.Errorf("Unexpected number of attestations, want: %v", 1)
}
assertAttFound(curAttsA, 1)
assertAttFound(prevAttsA, 1)
assertAttFound(curAttsB, 1)
assertAttFound(prevAttsB, 1)
// Extends state a attestations.
if err := a.AppendCurrentEpochAttestations(atts[1]); err != nil {
t.Fatal(err)
}
if err := a.AppendPreviousEpochAttestations(atts[1]); err != nil {
t.Fatal(err)
}
if len(a.CurrentEpochAttestations()) != 2 {
t.Errorf("Unexpected number of attestations, want: %v", 2)
}
if len(a.PreviousEpochAttestations()) != 2 {
t.Errorf("Unexpected number of attestations, want: %v", 2)
}
assertAttFound(a.state.GetCurrentEpochAttestations(), 1)
assertAttFound(a.state.GetPreviousEpochAttestations(), 1)
assertAttFound(a.state.GetCurrentEpochAttestations(), 2)
assertAttFound(a.state.GetPreviousEpochAttestations(), 2)
assertAttFound(b.state.GetCurrentEpochAttestations(), 1)
assertAttFound(b.state.GetPreviousEpochAttestations(), 1)
assertAttNotFound(b.state.GetCurrentEpochAttestations(), 2)
assertAttNotFound(b.state.GetPreviousEpochAttestations(), 2)
// Mutator should only affect calling state: a.
applyToEveryAttestation := func(state *p2ppb.BeaconState) {
// One MUST copy on write.
atts = make([]*p2ppb.PendingAttestation, len(state.CurrentEpochAttestations))
copy(atts, state.CurrentEpochAttestations)
state.CurrentEpochAttestations = atts
for i := range state.GetCurrentEpochAttestations() {
att := CopyPendingAttestation(state.CurrentEpochAttestations[i])
att.AggregationBits = bitfield.NewBitlist(3)
state.CurrentEpochAttestations[i] = att
}
atts = make([]*p2ppb.PendingAttestation, len(state.PreviousEpochAttestations))
copy(atts, state.PreviousEpochAttestations)
state.PreviousEpochAttestations = atts
for i := range state.GetPreviousEpochAttestations() {
att := CopyPendingAttestation(state.PreviousEpochAttestations[i])
att.AggregationBits = bitfield.NewBitlist(3)
state.PreviousEpochAttestations[i] = att
}
}
applyToEveryAttestation(a.state)
// Assert no shared state mutation occurred only on state a (copy on write).
assertAttFound(a.state.GetCurrentEpochAttestations(), 3)
assertAttFound(a.state.GetPreviousEpochAttestations(), 3)
assertAttNotFound(a.state.GetCurrentEpochAttestations(), 1)
assertAttNotFound(a.state.GetPreviousEpochAttestations(), 1)
assertAttNotFound(a.state.GetCurrentEpochAttestations(), 2)
assertAttNotFound(a.state.GetPreviousEpochAttestations(), 2)
// State b must be unaffected.
assertAttNotFound(b.state.GetCurrentEpochAttestations(), 3)
assertAttNotFound(b.state.GetPreviousEpochAttestations(), 3)
assertAttFound(b.state.GetCurrentEpochAttestations(), 1)
assertAttFound(b.state.GetPreviousEpochAttestations(), 1)
assertAttNotFound(b.state.GetCurrentEpochAttestations(), 2)
assertAttNotFound(b.state.GetPreviousEpochAttestations(), 2)
// Copy on write happened, reference counters are reset.
assertRefCount(t, a, currentEpochAttestations, 1)
assertRefCount(t, b, currentEpochAttestations, 1)
assertRefCount(t, a, previousEpochAttestations, 1)
assertRefCount(t, b, previousEpochAttestations, 1)
}
// assertRefCount checks whether reference count for a given state
// at a given index is equal to expected amount.
func assertRefCount(t *testing.T, b *BeaconState, idx fieldIndex, want uint) {
if cnt := b.sharedFieldReferences[idx].refs; cnt != want {
t.Errorf("Unexpected count of references for index %d, want: %v, got: %v", idx, want, cnt)
}
}
// assertValFound checks whether item with a given value exists in list.
func assertValFound(t *testing.T, vals [][]byte, val []byte) {
for i := range vals {
if reflect.DeepEqual(vals[i], val) {
return
}
}
t.Log(string(debug.Stack()))
t.Fatalf("Expected value not found (%v), want: %v", vals, val)
}
// assertValNotFound checks whether item with a given value doesn't exist in list.
func assertValNotFound(t *testing.T, vals [][]byte, val []byte) {
for i := range vals {
if reflect.DeepEqual(vals[i], val) {
t.Log(string(debug.Stack()))
t.Errorf("Unexpected value found (%v),: %v", vals, val)
return
}
}
}

View File

@@ -9,7 +9,9 @@ import (
"github.com/prysmaticlabs/go-bitfield"
coreutils "github.com/prysmaticlabs/prysm/beacon-chain/core/state/stateutils"
pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
"github.com/prysmaticlabs/prysm/shared/featureconfig"
"github.com/prysmaticlabs/prysm/shared/hashutil"
"github.com/prysmaticlabs/prysm/shared/memorypool"
)
// SetGenesisTime for the beacon state.
@@ -106,7 +108,12 @@ func (b *BeaconState) UpdateBlockRootAtIndex(idx uint64, blockRoot [32]byte) err
r := b.state.BlockRoots
if ref := b.sharedFieldReferences[blockRoots]; ref.refs > 1 {
// Copy on write since this is a shared array.
r = b.BlockRoots()
if featureconfig.Get().EnableStateRefCopy {
r = make([][]byte, len(b.state.BlockRoots))
copy(r, b.state.BlockRoots)
} else {
r = b.BlockRoots()
}
ref.MinusRef()
b.sharedFieldReferences[blockRoots] = &reference{refs: 1}
@@ -158,7 +165,12 @@ func (b *BeaconState) UpdateStateRootAtIndex(idx uint64, stateRoot [32]byte) err
r := b.state.StateRoots
if ref := b.sharedFieldReferences[stateRoots]; ref.refs > 1 {
// Perform a copy since this is a shared reference and we don't want to mutate others.
r = b.StateRoots()
if featureconfig.Get().EnableStateRefCopy {
r = make([][]byte, len(b.state.StateRoots))
copy(r, b.state.StateRoots)
} else {
r = b.StateRoots()
}
ref.MinusRef()
b.sharedFieldReferences[stateRoots] = &reference{refs: 1}
@@ -234,7 +246,12 @@ func (b *BeaconState) AppendEth1DataVotes(val *ethpb.Eth1Data) error {
b.lock.RLock()
votes := b.state.Eth1DataVotes
if b.sharedFieldReferences[eth1DataVotes].refs > 1 {
votes = b.Eth1DataVotes()
if featureconfig.Get().EnableStateRefCopy {
votes = make([]*ethpb.Eth1Data, len(b.state.Eth1DataVotes))
copy(votes, b.state.Eth1DataVotes)
} else {
votes = b.Eth1DataVotes()
}
b.sharedFieldReferences[eth1DataVotes].MinusRef()
b.sharedFieldReferences[eth1DataVotes] = &reference{refs: 1}
}
@@ -282,7 +299,8 @@ func (b *BeaconState) SetValidators(val []*ethpb.Validator) error {
// ApplyToEveryValidator applies the provided callback function to each validator in the
// validator registry.
func (b *BeaconState) ApplyToEveryValidator(f func(idx int, val *ethpb.Validator) (bool, error)) error {
func (b *BeaconState) ApplyToEveryValidator(checker func(idx int, val *ethpb.Validator) (bool, error),
mutator func(idx int, val *ethpb.Validator) error) error {
if !b.HasInnerState() {
return ErrNilInnerState
}
@@ -290,21 +308,35 @@ func (b *BeaconState) ApplyToEveryValidator(f func(idx int, val *ethpb.Validator
v := b.state.Validators
if ref := b.sharedFieldReferences[validators]; ref.refs > 1 {
// Perform a copy since this is a shared reference and we don't want to mutate others.
v = b.Validators()
if featureconfig.Get().EnableStateRefCopy {
v = make([]*ethpb.Validator, len(b.state.Validators))
copy(v, b.state.Validators)
} else {
v = b.Validators()
}
ref.MinusRef()
b.sharedFieldReferences[validators] = &reference{refs: 1}
}
b.lock.RUnlock()
changedVals := []uint64{}
var changedVals []uint64
for i, val := range v {
changed, err := f(i, val)
changed, err := checker(i, val)
if err != nil {
return err
}
if changed {
changedVals = append(changedVals, uint64(i))
if !changed {
continue
}
if featureconfig.Get().EnableStateRefCopy {
// copy if changing a reference
val = CopyValidator(val)
}
err = mutator(i, val)
if err != nil {
return err
}
changedVals = append(changedVals, uint64(i))
v[i] = val
}
b.lock.Lock()
@@ -331,7 +363,12 @@ func (b *BeaconState) UpdateValidatorAtIndex(idx uint64, val *ethpb.Validator) e
v := b.state.Validators
if ref := b.sharedFieldReferences[validators]; ref.refs > 1 {
// Perform a copy since this is a shared reference and we don't want to mutate others.
v = b.Validators()
if featureconfig.Get().EnableStateRefCopy {
v = make([]*ethpb.Validator, len(b.state.Validators))
copy(v, b.state.Validators)
} else {
v = b.Validators()
}
ref.MinusRef()
b.sharedFieldReferences[validators] = &reference{refs: 1}
@@ -438,7 +475,12 @@ func (b *BeaconState) UpdateRandaoMixesAtIndex(idx uint64, val []byte) error {
b.lock.RLock()
mixes := b.state.RandaoMixes
if refs := b.sharedFieldReferences[randaoMixes].refs; refs > 1 {
mixes = b.RandaoMixes()
if featureconfig.Get().EnableStateRefCopy {
mixes = memorypool.GetDoubleByteSlice(len(b.state.RandaoMixes))
copy(mixes, b.state.RandaoMixes)
} else {
mixes = b.RandaoMixes()
}
b.sharedFieldReferences[randaoMixes].MinusRef()
b.sharedFieldReferences[randaoMixes] = &reference{refs: 1}
}
@@ -547,7 +589,12 @@ func (b *BeaconState) AppendHistoricalRoots(root [32]byte) error {
b.lock.RLock()
roots := b.state.HistoricalRoots
if b.sharedFieldReferences[historicalRoots].refs > 1 {
roots = b.HistoricalRoots()
if featureconfig.Get().EnableStateRefCopy {
roots = make([][]byte, len(b.state.HistoricalRoots))
copy(roots, b.state.HistoricalRoots)
} else {
roots = b.HistoricalRoots()
}
b.sharedFieldReferences[historicalRoots].MinusRef()
b.sharedFieldReferences[historicalRoots] = &reference{refs: 1}
}
@@ -571,7 +618,12 @@ func (b *BeaconState) AppendCurrentEpochAttestations(val *pbp2p.PendingAttestati
atts := b.state.CurrentEpochAttestations
if b.sharedFieldReferences[currentEpochAttestations].refs > 1 {
atts = b.CurrentEpochAttestations()
if featureconfig.Get().EnableStateRefCopy {
atts = make([]*pbp2p.PendingAttestation, len(b.state.CurrentEpochAttestations))
copy(atts, b.state.CurrentEpochAttestations)
} else {
atts = b.CurrentEpochAttestations()
}
b.sharedFieldReferences[currentEpochAttestations].MinusRef()
b.sharedFieldReferences[currentEpochAttestations] = &reference{refs: 1}
}
@@ -595,7 +647,12 @@ func (b *BeaconState) AppendPreviousEpochAttestations(val *pbp2p.PendingAttestat
b.lock.RLock()
atts := b.state.PreviousEpochAttestations
if b.sharedFieldReferences[previousEpochAttestations].refs > 1 {
atts = b.PreviousEpochAttestations()
if featureconfig.Get().EnableStateRefCopy {
atts = make([]*pbp2p.PendingAttestation, len(b.state.PreviousEpochAttestations))
copy(atts, b.state.PreviousEpochAttestations)
} else {
atts = b.PreviousEpochAttestations()
}
b.sharedFieldReferences[previousEpochAttestations].MinusRef()
b.sharedFieldReferences[previousEpochAttestations] = &reference{refs: 1}
}
@@ -620,7 +677,12 @@ func (b *BeaconState) AppendValidator(val *ethpb.Validator) error {
b.lock.RLock()
vals := b.state.Validators
if b.sharedFieldReferences[validators].refs > 1 {
vals = b.Validators()
if featureconfig.Get().EnableStateRefCopy {
vals = make([]*ethpb.Validator, len(b.state.Validators))
copy(vals, b.state.Validators)
} else {
vals = b.Validators()
}
b.sharedFieldReferences[validators].MinusRef()
b.sharedFieldReferences[validators] = &reference{refs: 1}
}

View File

@@ -165,11 +165,11 @@ func (b *BeaconState) Copy() *BeaconState {
memorypool.PutRandaoMixesTrie(b.stateFieldLeaves[randaoMixes].fieldLayers)
}
}
if field == blockRoots && v.refs == 0 && b.stateFieldLeaves[field].refs == 0 {
memorypool.PutBlockRootsTrie(b.stateFieldLeaves[blockRoots].fieldLayers)
if (field == blockRoots || field == stateRoots) && v.refs == 0 && b.stateFieldLeaves[field].refs == 0 {
memorypool.PutRootsTrie(b.stateFieldLeaves[field].fieldLayers)
}
if field == stateRoots && v.refs == 0 && b.stateFieldLeaves[field].refs == 0 {
memorypool.PutStateRootsTrie(b.stateFieldLeaves[stateRoots].fieldLayers)
if field == validators && v.refs == 0 && b.stateFieldLeaves[field].refs == 0 {
memorypool.PutValidatorsTrie(b.stateFieldLeaves[validators].fieldLayers)
}
}
})

View File

@@ -51,6 +51,7 @@ type Flags struct {
EnableFieldTrie bool // EnableFieldTrie enables the state from using field specific tries when computing the root.
EnableBlockHTR bool // EnableBlockHTR enables custom hashing of our beacon blocks.
NoInitSyncBatchSaveBlocks bool // NoInitSyncBatchSaveBlocks disables batch save blocks mode during initial syncing.
EnableStateRefCopy bool // EnableStateRefCopy copies the references to objects instead of the objects themselves when copying state fields.
// DisableForkChoice disables using LMD-GHOST fork choice to update
// the head of the chain based on attestations and instead accepts any valid received block
// as the chain head. UNSAFE, use with caution.
@@ -191,6 +192,10 @@ func ConfigureBeaconChain(ctx *cli.Context) {
log.Warn("Disabling init sync batch save blocks mode")
cfg.NoInitSyncBatchSaveBlocks = true
}
if ctx.Bool(enableStateRefCopy.Name) {
log.Warn("Enabling state reference copy")
cfg.EnableStateRefCopy = true
}
Init(cfg)
}

View File

@@ -138,6 +138,10 @@ var (
Name: "disable-init-sync-batch-save-blocks",
Usage: "Instead of saving batch blocks to the DB during initial syncing, this disables batch saving of blocks",
}
enableStateRefCopy = &cli.BoolFlag{
Name: "enable-state-ref-copy",
Usage: "Enables the usage of a new copying method for our state fields.",
}
)
// Deprecated flags list.
@@ -365,6 +369,7 @@ var BeaconChainFlags = append(deprecatedFlags, []cli.Flag{
enableFieldTrie,
enableCustomBlockHTR,
disableInitSyncBatchSaveBlocks,
enableStateRefCopy,
}...)
// E2EBeaconChainFlags contains a list of the beacon chain feature flags to be tested in E2E.
@@ -375,4 +380,5 @@ var E2EBeaconChainFlags = []string{
"--enable-state-gen-sig-verify",
"--check-head-state",
"--enable-state-field-trie",
"--enable-state-ref-copy",
}

View File

@@ -10,18 +10,18 @@ import (
// for 2d byte slices.
var DoubleByteSlicePool = new(sync.Pool)
// BlockRootsMemoryPool represents the memory pool
// for block roots trie.
var BlockRootsMemoryPool = new(sync.Pool)
// StateRootsMemoryPool represents the memory pool
// for state roots trie.
var StateRootsMemoryPool = new(sync.Pool)
// RootsMemoryPool represents the memory pool
// for state roots/block roots trie.
var RootsMemoryPool = new(sync.Pool)
// RandaoMixesMemoryPool represents the memory pool
// for randao mixes trie.
var RandaoMixesMemoryPool = new(sync.Pool)
// ValidatorsMemoryPool represents the memory pool
// for 3d byte slices.
var ValidatorsMemoryPool = new(sync.Pool)
// GetDoubleByteSlice retrieves the 2d byte slice of
// the desired size from the memory pool.
func GetDoubleByteSlice(size int) [][]byte {
@@ -51,14 +51,14 @@ func PutDoubleByteSlice(data [][]byte) {
}
}
// GetBlockRootsTrie retrieves the 3d byte trie of
// GetRootsTrie retrieves the 3d byte trie of
// the desired size from the memory pool.
func GetBlockRootsTrie(size int) [][]*[32]byte {
func GetRootsTrie(size int) [][]*[32]byte {
if !featureconfig.Get().EnableByteMempool {
return make([][]*[32]byte, size)
}
rawObj := BlockRootsMemoryPool.Get()
rawObj := RootsMemoryPool.Get()
if rawObj == nil {
return make([][]*[32]byte, size)
}
@@ -72,40 +72,11 @@ func GetBlockRootsTrie(size int) [][]*[32]byte {
return append(byteSlice, make([][]*[32]byte, size-len(byteSlice))...)
}
// PutBlockRootsTrie places the provided 3d byte trie
// PutRootsTrie places the provided 3d byte trie
// in the memory pool.
func PutBlockRootsTrie(data [][]*[32]byte) {
func PutRootsTrie(data [][]*[32]byte) {
if featureconfig.Get().EnableByteMempool {
BlockRootsMemoryPool.Put(data)
}
}
// GetStateRootsTrie retrieves the 3d byte slice of
// the desired size from the memory pool.
func GetStateRootsTrie(size int) [][]*[32]byte {
if !featureconfig.Get().EnableByteMempool {
return make([][]*[32]byte, size)
}
rawObj := BlockRootsMemoryPool.Get()
if rawObj == nil {
return make([][]*[32]byte, size)
}
byteSlice, ok := rawObj.([][]*[32]byte)
if !ok {
return nil
}
if len(byteSlice) >= size {
return byteSlice[:size]
}
return append(byteSlice, make([][]*[32]byte, size-len(byteSlice))...)
}
// PutStateRootsTrie places the provided trie
// in the memory pool.
func PutStateRootsTrie(data [][]*[32]byte) {
if featureconfig.Get().EnableByteMempool {
StateRootsMemoryPool.Put(data)
RootsMemoryPool.Put(data)
}
}
@@ -116,7 +87,7 @@ func GetRandaoMixesTrie(size int) [][]*[32]byte {
return make([][]*[32]byte, size)
}
rawObj := StateRootsMemoryPool.Get()
rawObj := RandaoMixesMemoryPool.Get()
if rawObj == nil {
return make([][]*[32]byte, size)
}
@@ -134,6 +105,34 @@ func GetRandaoMixesTrie(size int) [][]*[32]byte {
// in the memory pool.
func PutRandaoMixesTrie(data [][]*[32]byte) {
if featureconfig.Get().EnableByteMempool {
StateRootsMemoryPool.Put(data)
RandaoMixesMemoryPool.Put(data)
}
}
// GetValidatorsTrie retrieves the 3d byte slice of
// the desired size from the memory pool.
func GetValidatorsTrie(size int) [][]*[32]byte {
if !featureconfig.Get().EnableByteMempool {
return make([][]*[32]byte, size)
}
rawObj := ValidatorsMemoryPool.Get()
if rawObj == nil {
return make([][]*[32]byte, size)
}
byteSlice, ok := rawObj.([][]*[32]byte)
if !ok {
return nil
}
if len(byteSlice) >= size {
return byteSlice[:size]
}
return append(byteSlice, make([][]*[32]byte, size-len(byteSlice))...)
}
// PutValidatorsTrie places the provided 3d byte slice
// in the memory pool.
func PutValidatorsTrie(data [][]*[32]byte) {
if featureconfig.Get().EnableByteMempool {
ValidatorsMemoryPool.Put(data)
}
}