From 4de92bafc4bb0051381fba48eb36830aa181d7ee Mon Sep 17 00:00:00 2001 From: Nishant Das Date: Thu, 16 Jun 2022 21:14:29 +0800 Subject: [PATCH] Improve Field Trie Recomputation (#10884) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Raul Jordan Co-authored-by: RadosÅ‚aw Kapka --- beacon-chain/state/fieldtrie/field_trie.go | 52 ++++- beacon-chain/state/state-native/state_trie.go | 24 ++- .../state/state-native/state_trie_test.go | 201 ++++++++++++++++++ beacon-chain/state/v1/state_trie.go | 24 ++- beacon-chain/state/v1/state_trie_test.go | 67 ++++++ beacon-chain/state/v2/BUILD.bazel | 1 + beacon-chain/state/v2/state_test.go | 79 +++++++ beacon-chain/state/v2/state_trie.go | 24 ++- beacon-chain/state/v3/BUILD.bazel | 1 + beacon-chain/state/v3/state_test.go | 79 +++++++ beacon-chain/state/v3/state_trie.go | 25 ++- 11 files changed, 551 insertions(+), 26 deletions(-) create mode 100644 beacon-chain/state/v2/state_test.go create mode 100644 beacon-chain/state/v3/state_test.go diff --git a/beacon-chain/state/fieldtrie/field_trie.go b/beacon-chain/state/fieldtrie/field_trie.go index d578109374..4dcf7a30ac 100644 --- a/beacon-chain/state/fieldtrie/field_trie.go +++ b/beacon-chain/state/fieldtrie/field_trie.go @@ -19,12 +19,13 @@ var ( // trie of the particular field. type FieldTrie struct { *sync.RWMutex - reference *stateutil.Reference - fieldLayers [][]*[32]byte - field types.BeaconStateField - dataType types.DataType - length uint64 - numOfElems int + reference *stateutil.Reference + fieldLayers [][]*[32]byte + field types.BeaconStateField + dataType types.DataType + length uint64 + numOfElems int + isTransferred bool } // NewFieldTrie is the constructor for the field trie data structure. It creates the corresponding @@ -191,6 +192,43 @@ func (f *FieldTrie) CopyTrie() *FieldTrie { } } +// Length return the length of the whole field trie. +func (f *FieldTrie) Length() uint64 { + return f.length +} + +// TransferTrie starts the process of transferring all the +// trie related data to a new trie. This is done if we +// know that other states which hold references to this +// trie will unlikely need it for recomputation. This helps +// us save on a copy. Any caller of this method will need +// to take care that this isn't called on an empty trie. +func (f *FieldTrie) TransferTrie() *FieldTrie { + if f.fieldLayers == nil { + return &FieldTrie{ + field: f.field, + dataType: f.dataType, + reference: stateutil.NewRef(1), + RWMutex: new(sync.RWMutex), + length: f.length, + numOfElems: f.numOfElems, + } + } + f.isTransferred = true + nTrie := &FieldTrie{ + fieldLayers: f.fieldLayers, + field: f.field, + dataType: f.dataType, + reference: stateutil.NewRef(1), + RWMutex: new(sync.RWMutex), + length: f.length, + numOfElems: f.numOfElems, + } + // Zero out field layers here. + f.fieldLayers = nil + return nTrie +} + // TrieRoot returns the corresponding root of the trie. func (f *FieldTrie) TrieRoot() ([32]byte, error) { if f.Empty() { @@ -222,7 +260,7 @@ func (f *FieldTrie) FieldReference() *stateutil.Reference { // Empty checks whether the underlying field trie is // empty or not. func (f *FieldTrie) Empty() bool { - return f == nil || len(f.fieldLayers) == 0 + return f == nil || len(f.fieldLayers) == 0 || f.isTransferred } // InsertFieldLayer manually inserts a field layer. This method diff --git a/beacon-chain/state/state-native/state_trie.go b/beacon-chain/state/state-native/state_trie.go index 92ed2cdb95..4dd881d483 100644 --- a/beacon-chain/state/state-native/state_trie.go +++ b/beacon-chain/state/state-native/state_trie.go @@ -740,17 +740,31 @@ func (b *BeaconState) rootSelector(ctx context.Context, field nativetypes.FieldI func (b *BeaconState) recomputeFieldTrie(index nativetypes.FieldIndex, elements interface{}) ([32]byte, error) { fTrie := b.stateFieldLeaves[index] + fTrieMutex := fTrie.RWMutex // We can't lock the trie directly because the trie's variable gets reassigned, // and therefore we would call Unlock() on a different object. - fTrieMutex := fTrie.RWMutex - if fTrie.FieldReference().Refs() > 1 { - fTrieMutex.Lock() + fTrieMutex.Lock() + + if fTrie.Empty() { + err := b.resetFieldTrie(index, elements, fTrie.Length()) + if err != nil { + fTrieMutex.Unlock() + return [32]byte{}, err + } + // Reduce reference count as we are instantiating a new trie. fTrie.FieldReference().MinusRef() - newTrie := fTrie.CopyTrie() + fTrieMutex.Unlock() + return b.stateFieldLeaves[index].TrieRoot() + } + + if fTrie.FieldReference().Refs() > 1 { + fTrie.FieldReference().MinusRef() + newTrie := fTrie.TransferTrie() b.stateFieldLeaves[index] = newTrie fTrie = newTrie - fTrieMutex.Unlock() } + fTrieMutex.Unlock() + // remove duplicate indexes b.dirtyIndices[index] = slice.SetUint64(b.dirtyIndices[index]) // sort indexes again diff --git a/beacon-chain/state/state-native/state_trie_test.go b/beacon-chain/state/state-native/state_trie_test.go index ab23124ba4..3f901470ec 100644 --- a/beacon-chain/state/state-native/state_trie_test.go +++ b/beacon-chain/state/state-native/state_trie_test.go @@ -383,3 +383,204 @@ func TestBeaconState_AppendValidator_DoesntMutateCopy(t *testing.T) { _, ok := st1.ValidatorIndexByPubkey(bytesutil.ToBytes48(val.PublicKey)) assert.Equal(t, false, ok, "Expected no validator index to be present in st1 for the newly inserted pubkey") } + +func TestBeaconState_ValidatorMutation_Phase0(t *testing.T) { + testState, _ := util.DeterministicGenesisState(t, 400) + pbState, err := statenative.ProtobufBeaconStatePhase0(testState.InnerStateUnsafe()) + require.NoError(t, err) + testState, err = statenative.InitializeFromProtoPhase0(pbState) + require.NoError(t, err) + + _, err = testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + // Reset tries + require.NoError(t, testState.UpdateValidatorAtIndex(200, new(ethpb.Validator))) + _, err = testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + newState1 := testState.Copy() + _ = testState.Copy() + + require.NoError(t, testState.UpdateValidatorAtIndex(15, ðpb.Validator{ + PublicKey: make([]byte, 48), + WithdrawalCredentials: make([]byte, 32), + EffectiveBalance: 1111, + Slashed: false, + ActivationEligibilityEpoch: 1112, + ActivationEpoch: 1114, + ExitEpoch: 1116, + WithdrawableEpoch: 1117, + })) + + rt, err := testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + pbState, err = statenative.ProtobufBeaconStatePhase0(testState.InnerStateUnsafe()) + require.NoError(t, err) + + copiedTestState, err := statenative.InitializeFromProtoPhase0(pbState) + require.NoError(t, err) + + rt2, err := copiedTestState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + assert.Equal(t, rt, rt2) + + require.NoError(t, newState1.UpdateValidatorAtIndex(150, ðpb.Validator{ + PublicKey: make([]byte, 48), + WithdrawalCredentials: make([]byte, 32), + EffectiveBalance: 2111, + Slashed: false, + ActivationEligibilityEpoch: 2112, + ActivationEpoch: 2114, + ExitEpoch: 2116, + WithdrawableEpoch: 2117, + })) + + rt, err = newState1.HashTreeRoot(context.Background()) + require.NoError(t, err) + pbState, err = statenative.ProtobufBeaconStatePhase0(newState1.InnerStateUnsafe()) + require.NoError(t, err) + + copiedTestState, err = statenative.InitializeFromProtoPhase0(pbState) + require.NoError(t, err) + + rt2, err = copiedTestState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + assert.Equal(t, rt, rt2) +} + +func TestBeaconState_ValidatorMutation_Altair(t *testing.T) { + testState, _ := util.DeterministicGenesisStateAltair(t, 400) + pbState, err := statenative.ProtobufBeaconStateAltair(testState.InnerStateUnsafe()) + require.NoError(t, err) + testState, err = statenative.InitializeFromProtoAltair(pbState) + require.NoError(t, err) + + _, err = testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + // Reset tries + require.NoError(t, testState.UpdateValidatorAtIndex(200, new(ethpb.Validator))) + _, err = testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + newState1 := testState.Copy() + _ = testState.Copy() + + require.NoError(t, testState.UpdateValidatorAtIndex(15, ðpb.Validator{ + PublicKey: make([]byte, 48), + WithdrawalCredentials: make([]byte, 32), + EffectiveBalance: 1111, + Slashed: false, + ActivationEligibilityEpoch: 1112, + ActivationEpoch: 1114, + ExitEpoch: 1116, + WithdrawableEpoch: 1117, + })) + + rt, err := testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + pbState, err = statenative.ProtobufBeaconStateAltair(testState.InnerStateUnsafe()) + require.NoError(t, err) + + copiedTestState, err := statenative.InitializeFromProtoAltair(pbState) + require.NoError(t, err) + + rt2, err := copiedTestState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + assert.Equal(t, rt, rt2) + + require.NoError(t, newState1.UpdateValidatorAtIndex(150, ðpb.Validator{ + PublicKey: make([]byte, 48), + WithdrawalCredentials: make([]byte, 32), + EffectiveBalance: 2111, + Slashed: false, + ActivationEligibilityEpoch: 2112, + ActivationEpoch: 2114, + ExitEpoch: 2116, + WithdrawableEpoch: 2117, + })) + + rt, err = newState1.HashTreeRoot(context.Background()) + require.NoError(t, err) + pbState, err = statenative.ProtobufBeaconStateAltair(newState1.InnerStateUnsafe()) + require.NoError(t, err) + + copiedTestState, err = statenative.InitializeFromProtoAltair(pbState) + require.NoError(t, err) + + rt2, err = copiedTestState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + assert.Equal(t, rt, rt2) +} + +func TestBeaconState_ValidatorMutation_Bellatrix(t *testing.T) { + testState, _ := util.DeterministicGenesisStateBellatrix(t, 400) + pbState, err := statenative.ProtobufBeaconStateBellatrix(testState.InnerStateUnsafe()) + require.NoError(t, err) + testState, err = statenative.InitializeFromProtoBellatrix(pbState) + require.NoError(t, err) + + _, err = testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + // Reset tries + require.NoError(t, testState.UpdateValidatorAtIndex(200, new(ethpb.Validator))) + _, err = testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + newState1 := testState.Copy() + _ = testState.Copy() + + require.NoError(t, testState.UpdateValidatorAtIndex(15, ðpb.Validator{ + PublicKey: make([]byte, 48), + WithdrawalCredentials: make([]byte, 32), + EffectiveBalance: 1111, + Slashed: false, + ActivationEligibilityEpoch: 1112, + ActivationEpoch: 1114, + ExitEpoch: 1116, + WithdrawableEpoch: 1117, + })) + + rt, err := testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + pbState, err = statenative.ProtobufBeaconStateBellatrix(testState.InnerStateUnsafe()) + require.NoError(t, err) + + copiedTestState, err := statenative.InitializeFromProtoBellatrix(pbState) + require.NoError(t, err) + + rt2, err := copiedTestState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + assert.Equal(t, rt, rt2) + + require.NoError(t, newState1.UpdateValidatorAtIndex(150, ðpb.Validator{ + PublicKey: make([]byte, 48), + WithdrawalCredentials: make([]byte, 32), + EffectiveBalance: 2111, + Slashed: false, + ActivationEligibilityEpoch: 2112, + ActivationEpoch: 2114, + ExitEpoch: 2116, + WithdrawableEpoch: 2117, + })) + + rt, err = newState1.HashTreeRoot(context.Background()) + require.NoError(t, err) + pbState, err = statenative.ProtobufBeaconStateBellatrix(newState1.InnerStateUnsafe()) + require.NoError(t, err) + + copiedTestState, err = statenative.InitializeFromProtoBellatrix(pbState) + require.NoError(t, err) + + rt2, err = copiedTestState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + assert.Equal(t, rt, rt2) +} diff --git a/beacon-chain/state/v1/state_trie.go b/beacon-chain/state/v1/state_trie.go index c4271ee67e..5e248bf387 100644 --- a/beacon-chain/state/v1/state_trie.go +++ b/beacon-chain/state/v1/state_trie.go @@ -390,17 +390,31 @@ func (b *BeaconState) rootSelector(ctx context.Context, field types.FieldIndex) func (b *BeaconState) recomputeFieldTrie(index types.FieldIndex, elements interface{}) ([32]byte, error) { fTrie := b.stateFieldLeaves[index] + fTrieMutex := fTrie.RWMutex // We can't lock the trie directly because the trie's variable gets reassigned, // and therefore we would call Unlock() on a different object. - fTrieMutex := fTrie.RWMutex - if fTrie.FieldReference().Refs() > 1 { - fTrieMutex.Lock() + fTrieMutex.Lock() + + if fTrie.Empty() { + err := b.resetFieldTrie(index, elements, fTrie.Length()) + if err != nil { + fTrieMutex.Unlock() + return [32]byte{}, err + } + // Reduce reference count as we are instantiating a new trie. fTrie.FieldReference().MinusRef() - newTrie := fTrie.CopyTrie() + fTrieMutex.Unlock() + return b.stateFieldLeaves[index].TrieRoot() + } + + if fTrie.FieldReference().Refs() > 1 { + fTrie.FieldReference().MinusRef() + newTrie := fTrie.TransferTrie() b.stateFieldLeaves[index] = newTrie fTrie = newTrie - fTrieMutex.Unlock() } + fTrieMutex.Unlock() + // remove duplicate indexes b.dirtyIndices[index] = slice.SetUint64(b.dirtyIndices[index]) // sort indexes again diff --git a/beacon-chain/state/v1/state_trie_test.go b/beacon-chain/state/v1/state_trie_test.go index 478eedc4b8..996b076a1a 100644 --- a/beacon-chain/state/v1/state_trie_test.go +++ b/beacon-chain/state/v1/state_trie_test.go @@ -269,3 +269,70 @@ func TestBeaconState_AppendValidator_DoesntMutateCopy(t *testing.T) { _, ok := st1.ValidatorIndexByPubkey(bytesutil.ToBytes48(val.PublicKey)) assert.Equal(t, false, ok, "Expected no validator index to be present in st1 for the newly inserted pubkey") } + +func TestBeaconState_ValidatorMutation_Phase0(t *testing.T) { + testState, _ := util.DeterministicGenesisState(t, 400) + pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe()) + require.NoError(t, err) + testState, err = v1.InitializeFromProto(pbState) + require.NoError(t, err) + + _, err = testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + // Reset tries + require.NoError(t, testState.UpdateValidatorAtIndex(200, new(ethpb.Validator))) + _, err = testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + newState1 := testState.Copy() + _ = testState.Copy() + + require.NoError(t, testState.UpdateValidatorAtIndex(15, ðpb.Validator{ + PublicKey: make([]byte, 48), + WithdrawalCredentials: make([]byte, 32), + EffectiveBalance: 1111, + Slashed: false, + ActivationEligibilityEpoch: 1112, + ActivationEpoch: 1114, + ExitEpoch: 1116, + WithdrawableEpoch: 1117, + })) + + rt, err := testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + pbState, err = v1.ProtobufBeaconState(testState.InnerStateUnsafe()) + require.NoError(t, err) + + copiedTestState, err := v1.InitializeFromProto(pbState) + require.NoError(t, err) + + rt2, err := copiedTestState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + assert.Equal(t, rt, rt2) + + require.NoError(t, newState1.UpdateValidatorAtIndex(150, ðpb.Validator{ + PublicKey: make([]byte, 48), + WithdrawalCredentials: make([]byte, 32), + EffectiveBalance: 2111, + Slashed: false, + ActivationEligibilityEpoch: 2112, + ActivationEpoch: 2114, + ExitEpoch: 2116, + WithdrawableEpoch: 2117, + })) + + rt, err = newState1.HashTreeRoot(context.Background()) + require.NoError(t, err) + pbState, err = v1.ProtobufBeaconState(newState1.InnerStateUnsafe()) + require.NoError(t, err) + + copiedTestState, err = v1.InitializeFromProto(pbState) + require.NoError(t, err) + + rt2, err = copiedTestState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + assert.Equal(t, rt, rt2) +} diff --git a/beacon-chain/state/v2/BUILD.bazel b/beacon-chain/state/v2/BUILD.bazel index d92b4cce94..fc6339ad03 100644 --- a/beacon-chain/state/v2/BUILD.bazel +++ b/beacon-chain/state/v2/BUILD.bazel @@ -71,6 +71,7 @@ go_test( "references_test.go", "setters_test.go", "state_fuzz_test.go", + "state_test.go", "state_trie_test.go", ], data = glob(["testdata/**"]), diff --git a/beacon-chain/state/v2/state_test.go b/beacon-chain/state/v2/state_test.go new file mode 100644 index 0000000000..fea327b834 --- /dev/null +++ b/beacon-chain/state/v2/state_test.go @@ -0,0 +1,79 @@ +package v2_test + +import ( + "context" + "testing" + + v2 "github.com/prysmaticlabs/prysm/beacon-chain/state/v2" + ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/testing/assert" + "github.com/prysmaticlabs/prysm/testing/require" + "github.com/prysmaticlabs/prysm/testing/util" +) + +func TestBeaconState_ValidatorMutation_Altair(t *testing.T) { + testState, _ := util.DeterministicGenesisStateAltair(t, 400) + pbState, err := v2.ProtobufBeaconState(testState.InnerStateUnsafe()) + require.NoError(t, err) + testState, err = v2.InitializeFromProto(pbState) + require.NoError(t, err) + + _, err = testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + // Reset tries + require.NoError(t, testState.UpdateValidatorAtIndex(200, new(ethpb.Validator))) + _, err = testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + newState1 := testState.Copy() + _ = testState.Copy() + + require.NoError(t, testState.UpdateValidatorAtIndex(15, ðpb.Validator{ + PublicKey: make([]byte, 48), + WithdrawalCredentials: make([]byte, 32), + EffectiveBalance: 1111, + Slashed: false, + ActivationEligibilityEpoch: 1112, + ActivationEpoch: 1114, + ExitEpoch: 1116, + WithdrawableEpoch: 1117, + })) + + rt, err := testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + pbState, err = v2.ProtobufBeaconState(testState.InnerStateUnsafe()) + require.NoError(t, err) + + copiedTestState, err := v2.InitializeFromProto(pbState) + require.NoError(t, err) + + rt2, err := copiedTestState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + assert.Equal(t, rt, rt2) + + require.NoError(t, newState1.UpdateValidatorAtIndex(150, ðpb.Validator{ + PublicKey: make([]byte, 48), + WithdrawalCredentials: make([]byte, 32), + EffectiveBalance: 2111, + Slashed: false, + ActivationEligibilityEpoch: 2112, + ActivationEpoch: 2114, + ExitEpoch: 2116, + WithdrawableEpoch: 2117, + })) + + rt, err = newState1.HashTreeRoot(context.Background()) + require.NoError(t, err) + pbState, err = v2.ProtobufBeaconState(newState1.InnerStateUnsafe()) + require.NoError(t, err) + + copiedTestState, err = v2.InitializeFromProto(pbState) + require.NoError(t, err) + + rt2, err = copiedTestState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + assert.Equal(t, rt, rt2) +} diff --git a/beacon-chain/state/v2/state_trie.go b/beacon-chain/state/v2/state_trie.go index 8b15a2d3a3..d8e79ddedd 100644 --- a/beacon-chain/state/v2/state_trie.go +++ b/beacon-chain/state/v2/state_trie.go @@ -378,17 +378,31 @@ func (b *BeaconState) rootSelector(ctx context.Context, field types.FieldIndex) func (b *BeaconState) recomputeFieldTrie(index types.FieldIndex, elements interface{}) ([32]byte, error) { fTrie := b.stateFieldLeaves[index] + fTrieMutex := fTrie.RWMutex // We can't lock the trie directly because the trie's variable gets reassigned, // and therefore we would call Unlock() on a different object. - fTrieMutex := fTrie.RWMutex - if fTrie.FieldReference().Refs() > 1 { - fTrieMutex.Lock() + fTrieMutex.Lock() + + if fTrie.Empty() { + err := b.resetFieldTrie(index, elements, fTrie.Length()) + if err != nil { + fTrieMutex.Unlock() + return [32]byte{}, err + } + // Reduce reference count as we are instantiating a new trie. fTrie.FieldReference().MinusRef() - newTrie := fTrie.CopyTrie() + fTrieMutex.Unlock() + return b.stateFieldLeaves[index].TrieRoot() + } + + if fTrie.FieldReference().Refs() > 1 { + fTrie.FieldReference().MinusRef() + newTrie := fTrie.TransferTrie() b.stateFieldLeaves[index] = newTrie fTrie = newTrie - fTrieMutex.Unlock() } + fTrieMutex.Unlock() + // remove duplicate indexes b.dirtyIndices[index] = slice.SetUint64(b.dirtyIndices[index]) // sort indexes again diff --git a/beacon-chain/state/v3/BUILD.bazel b/beacon-chain/state/v3/BUILD.bazel index 973cc0a9c1..700e718c6c 100644 --- a/beacon-chain/state/v3/BUILD.bazel +++ b/beacon-chain/state/v3/BUILD.bazel @@ -73,6 +73,7 @@ go_test( "references_test.go", "setters_test.go", "state_fuzz_test.go", + "state_test.go", "state_trie_test.go", ], embed = [":go_default_library"], diff --git a/beacon-chain/state/v3/state_test.go b/beacon-chain/state/v3/state_test.go new file mode 100644 index 0000000000..99805092b6 --- /dev/null +++ b/beacon-chain/state/v3/state_test.go @@ -0,0 +1,79 @@ +package v3_test + +import ( + "context" + "testing" + + v3 "github.com/prysmaticlabs/prysm/beacon-chain/state/v3" + ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/testing/assert" + "github.com/prysmaticlabs/prysm/testing/require" + "github.com/prysmaticlabs/prysm/testing/util" +) + +func TestBeaconState_ValidatorMutation_Bellatrix(t *testing.T) { + testState, _ := util.DeterministicGenesisStateBellatrix(t, 400) + pbState, err := v3.ProtobufBeaconState(testState.InnerStateUnsafe()) + require.NoError(t, err) + testState, err = v3.InitializeFromProto(pbState) + require.NoError(t, err) + + _, err = testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + // Reset tries + require.NoError(t, testState.UpdateValidatorAtIndex(200, new(ethpb.Validator))) + _, err = testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + newState1 := testState.Copy() + _ = testState.Copy() + + require.NoError(t, testState.UpdateValidatorAtIndex(15, ðpb.Validator{ + PublicKey: make([]byte, 48), + WithdrawalCredentials: make([]byte, 32), + EffectiveBalance: 1111, + Slashed: false, + ActivationEligibilityEpoch: 1112, + ActivationEpoch: 1114, + ExitEpoch: 1116, + WithdrawableEpoch: 1117, + })) + + rt, err := testState.HashTreeRoot(context.Background()) + require.NoError(t, err) + pbState, err = v3.ProtobufBeaconState(testState.InnerStateUnsafe()) + require.NoError(t, err) + + copiedTestState, err := v3.InitializeFromProtoUnsafe(pbState) + require.NoError(t, err) + + rt2, err := copiedTestState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + assert.Equal(t, rt, rt2) + + require.NoError(t, newState1.UpdateValidatorAtIndex(150, ðpb.Validator{ + PublicKey: make([]byte, 48), + WithdrawalCredentials: make([]byte, 32), + EffectiveBalance: 2111, + Slashed: false, + ActivationEligibilityEpoch: 2112, + ActivationEpoch: 2114, + ExitEpoch: 2116, + WithdrawableEpoch: 2117, + })) + + rt, err = newState1.HashTreeRoot(context.Background()) + require.NoError(t, err) + pbState, err = v3.ProtobufBeaconState(newState1.InnerStateUnsafe()) + require.NoError(t, err) + + copiedTestState, err = v3.InitializeFromProto(pbState) + require.NoError(t, err) + + rt2, err = copiedTestState.HashTreeRoot(context.Background()) + require.NoError(t, err) + + assert.Equal(t, rt, rt2) +} diff --git a/beacon-chain/state/v3/state_trie.go b/beacon-chain/state/v3/state_trie.go index 8c367cedec..5f0abe7a29 100644 --- a/beacon-chain/state/v3/state_trie.go +++ b/beacon-chain/state/v3/state_trie.go @@ -376,14 +376,31 @@ func (b *BeaconState) rootSelector(field types.FieldIndex) ([32]byte, error) { func (b *BeaconState) recomputeFieldTrie(index types.FieldIndex, elements interface{}) ([32]byte, error) { fTrie := b.stateFieldLeaves[index] - if fTrie.FieldReference().Refs() > 1 { - fTrie.Lock() - defer fTrie.Unlock() + fTrieMutex := fTrie.RWMutex + // We can't lock the trie directly because the trie's variable gets reassigned, + // and therefore we would call Unlock() on a different object. + fTrieMutex.Lock() + + if fTrie.Empty() { + err := b.resetFieldTrie(index, elements, fTrie.Length()) + if err != nil { + fTrieMutex.Unlock() + return [32]byte{}, err + } + // Reduce reference count as we are instantiating a new trie. fTrie.FieldReference().MinusRef() - newTrie := fTrie.CopyTrie() + fTrieMutex.Unlock() + return b.stateFieldLeaves[index].TrieRoot() + } + + if fTrie.FieldReference().Refs() > 1 { + fTrie.FieldReference().MinusRef() + newTrie := fTrie.TransferTrie() b.stateFieldLeaves[index] = newTrie fTrie = newTrie } + fTrieMutex.Unlock() + // remove duplicate indexes b.dirtyIndices[index] = slice.SetUint64(b.dirtyIndices[index]) // sort indexes again