From d4545233cd45d28f03e5dccb8929ce807f4468a6 Mon Sep 17 00:00:00 2001 From: Preston Van Loon Date: Fri, 19 Jun 2020 17:27:23 -0700 Subject: [PATCH] Fix race condition issues in beacon state (#6322) * Fix race condition issues in beacon state * Merge branch 'master' of github.com:prysmaticlabs/prysm into race-condition-fixes * Checkout beacon-chain/state/stateutil/BUILD.bazel from master * @nisdas PR feedback. defer unlock --- beacon-chain/state/field_trie.go | 10 +++---- beacon-chain/state/setters.go | 46 ++++++++++++++++---------------- beacon-chain/state/state_trie.go | 4 +-- beacon-chain/state/types.go | 11 ++++++++ 4 files changed, 41 insertions(+), 30 deletions(-) diff --git a/beacon-chain/state/field_trie.go b/beacon-chain/state/field_trie.go index 1890c7c1d5..580ad33bd7 100644 --- a/beacon-chain/state/field_trie.go +++ b/beacon-chain/state/field_trie.go @@ -28,7 +28,7 @@ func NewFieldTrie(field fieldIndex, elements interface{}, length uint64) (*Field if elements == nil { return &FieldTrie{ field: field, - reference: &reference{1}, + reference: &reference{refs: 1}, Mutex: new(sync.Mutex), }, nil } @@ -45,14 +45,14 @@ func NewFieldTrie(field fieldIndex, elements interface{}, length uint64) (*Field return &FieldTrie{ fieldLayers: stateutil.ReturnTrieLayer(fieldRoots, length), field: field, - reference: &reference{1}, + reference: &reference{refs: 1}, Mutex: new(sync.Mutex), }, nil case compositeArray: return &FieldTrie{ fieldLayers: stateutil.ReturnTrieLayerVariable(fieldRoots, length), field: field, - reference: &reference{1}, + reference: &reference{refs: 1}, Mutex: new(sync.Mutex), }, nil default: @@ -104,7 +104,7 @@ func (f *FieldTrie) CopyTrie() *FieldTrie { if f.fieldLayers == nil { return &FieldTrie{ field: f.field, - reference: &reference{1}, + reference: &reference{refs: 1}, Mutex: new(sync.Mutex), } } @@ -121,7 +121,7 @@ func (f *FieldTrie) CopyTrie() *FieldTrie { return &FieldTrie{ fieldLayers: dstFieldTrie, field: f.field, - reference: &reference{1}, + reference: &reference{refs: 1}, Mutex: new(sync.Mutex), } } diff --git a/beacon-chain/state/setters.go b/beacon-chain/state/setters.go index 52ad7f94d7..4d386b96d1 100644 --- a/beacon-chain/state/setters.go +++ b/beacon-chain/state/setters.go @@ -85,7 +85,7 @@ func (b *BeaconState) SetBlockRoots(val [][]byte) error { b.lock.Lock() defer b.lock.Unlock() - b.sharedFieldReferences[blockRoots].refs-- + b.sharedFieldReferences[blockRoots].MinusRef() b.sharedFieldReferences[blockRoots] = &reference{refs: 1} b.state.BlockRoots = val @@ -105,7 +105,7 @@ func (b *BeaconState) UpdateBlockRootAtIndex(idx uint64, blockRoot [32]byte) err b.lock.RLock() r := b.state.BlockRoots - if ref := b.sharedFieldReferences[blockRoots]; ref.refs > 1 { + if ref := b.sharedFieldReferences[blockRoots]; ref.Refs() > 1 { // Copy on write since this is a shared array. if featureconfig.Get().EnableStateRefCopy { r = make([][]byte, len(b.state.BlockRoots)) @@ -140,7 +140,7 @@ func (b *BeaconState) SetStateRoots(val [][]byte) error { b.lock.Lock() defer b.lock.Unlock() - b.sharedFieldReferences[stateRoots].refs-- + b.sharedFieldReferences[stateRoots].MinusRef() b.sharedFieldReferences[stateRoots] = &reference{refs: 1} b.state.StateRoots = val @@ -162,7 +162,7 @@ func (b *BeaconState) UpdateStateRootAtIndex(idx uint64, stateRoot [32]byte) err b.lock.RLock() // Check if we hold the only reference to the shared state roots slice. r := b.state.StateRoots - if ref := b.sharedFieldReferences[stateRoots]; ref.refs > 1 { + if ref := b.sharedFieldReferences[stateRoots]; ref.Refs() > 1 { // Perform a copy since this is a shared reference and we don't want to mutate others. if featureconfig.Get().EnableStateRefCopy { r = make([][]byte, len(b.state.StateRoots)) @@ -197,7 +197,7 @@ func (b *BeaconState) SetHistoricalRoots(val [][]byte) error { b.lock.Lock() defer b.lock.Unlock() - b.sharedFieldReferences[historicalRoots].refs-- + b.sharedFieldReferences[historicalRoots].MinusRef() b.sharedFieldReferences[historicalRoots] = &reference{refs: 1} b.state.HistoricalRoots = val @@ -227,7 +227,7 @@ func (b *BeaconState) SetEth1DataVotes(val []*ethpb.Eth1Data) error { b.lock.Lock() defer b.lock.Unlock() - b.sharedFieldReferences[eth1DataVotes].refs-- + b.sharedFieldReferences[eth1DataVotes].MinusRef() b.sharedFieldReferences[eth1DataVotes] = &reference{refs: 1} b.state.Eth1DataVotes = val @@ -244,7 +244,7 @@ func (b *BeaconState) AppendEth1DataVotes(val *ethpb.Eth1Data) error { } b.lock.RLock() votes := b.state.Eth1DataVotes - if b.sharedFieldReferences[eth1DataVotes].refs > 1 { + if b.sharedFieldReferences[eth1DataVotes].Refs() > 1 { if featureconfig.Get().EnableStateRefCopy { votes = make([]*ethpb.Eth1Data, len(b.state.Eth1DataVotes)) copy(votes, b.state.Eth1DataVotes) @@ -288,7 +288,7 @@ func (b *BeaconState) SetValidators(val []*ethpb.Validator) error { defer b.lock.Unlock() b.state.Validators = val - b.sharedFieldReferences[validators].refs-- + b.sharedFieldReferences[validators].MinusRef() b.sharedFieldReferences[validators] = &reference{refs: 1} b.markFieldAsDirty(validators) b.rebuildTrie[validators] = true @@ -304,7 +304,7 @@ func (b *BeaconState) ApplyToEveryValidator(f func(idx int, val *ethpb.Validator } b.lock.RLock() v := b.state.Validators - if ref := b.sharedFieldReferences[validators]; ref.refs > 1 { + if ref := b.sharedFieldReferences[validators]; ref.Refs() > 1 { // Perform a copy since this is a shared reference and we don't want to mutate others. v = b.Validators() @@ -345,7 +345,7 @@ func (b *BeaconState) UpdateValidatorAtIndex(idx uint64, val *ethpb.Validator) e b.lock.RLock() v := b.state.Validators - if ref := b.sharedFieldReferences[validators]; ref.refs > 1 { + if ref := b.sharedFieldReferences[validators]; ref.Refs() > 1 { // Perform a copy since this is a shared reference and we don't want to mutate others. v = b.Validators() @@ -387,7 +387,7 @@ func (b *BeaconState) SetBalances(val []uint64) error { b.lock.Lock() defer b.lock.Unlock() - b.sharedFieldReferences[balances].refs-- + b.sharedFieldReferences[balances].MinusRef() b.sharedFieldReferences[balances] = &reference{refs: 1} b.state.Balances = val @@ -407,7 +407,7 @@ func (b *BeaconState) UpdateBalancesAtIndex(idx uint64, val uint64) error { b.lock.RLock() bals := b.state.Balances - if b.sharedFieldReferences[balances].refs > 1 { + if b.sharedFieldReferences[balances].Refs() > 1 { bals = b.Balances() b.sharedFieldReferences[balances].MinusRef() b.sharedFieldReferences[balances] = &reference{refs: 1} @@ -432,7 +432,7 @@ func (b *BeaconState) SetRandaoMixes(val [][]byte) error { b.lock.Lock() defer b.lock.Unlock() - b.sharedFieldReferences[randaoMixes].refs-- + b.sharedFieldReferences[randaoMixes].MinusRef() b.sharedFieldReferences[randaoMixes] = &reference{refs: 1} b.state.RandaoMixes = val @@ -453,7 +453,7 @@ func (b *BeaconState) UpdateRandaoMixesAtIndex(idx uint64, val []byte) error { b.lock.RLock() mixes := b.state.RandaoMixes - if refs := b.sharedFieldReferences[randaoMixes].refs; refs > 1 { + if refs := b.sharedFieldReferences[randaoMixes].Refs(); refs > 1 { if featureconfig.Get().EnableStateRefCopy { mixes = make([][]byte, len(b.state.RandaoMixes)) copy(mixes, b.state.RandaoMixes) @@ -485,7 +485,7 @@ func (b *BeaconState) SetSlashings(val []uint64) error { b.lock.Lock() defer b.lock.Unlock() - b.sharedFieldReferences[slashings].refs-- + b.sharedFieldReferences[slashings].MinusRef() b.sharedFieldReferences[slashings] = &reference{refs: 1} b.state.Slashings = val @@ -505,7 +505,7 @@ func (b *BeaconState) UpdateSlashingsAtIndex(idx uint64, val uint64) error { b.lock.RLock() s := b.state.Slashings - if b.sharedFieldReferences[slashings].refs > 1 { + if b.sharedFieldReferences[slashings].Refs() > 1 { s = b.Slashings() b.sharedFieldReferences[slashings].MinusRef() b.sharedFieldReferences[slashings] = &reference{refs: 1} @@ -532,7 +532,7 @@ func (b *BeaconState) SetPreviousEpochAttestations(val []*pbp2p.PendingAttestati b.lock.Lock() defer b.lock.Unlock() - b.sharedFieldReferences[previousEpochAttestations].refs-- + b.sharedFieldReferences[previousEpochAttestations].MinusRef() b.sharedFieldReferences[previousEpochAttestations] = &reference{refs: 1} b.state.PreviousEpochAttestations = val @@ -550,7 +550,7 @@ func (b *BeaconState) SetCurrentEpochAttestations(val []*pbp2p.PendingAttestatio b.lock.Lock() defer b.lock.Unlock() - b.sharedFieldReferences[currentEpochAttestations].refs-- + b.sharedFieldReferences[currentEpochAttestations].MinusRef() b.sharedFieldReferences[currentEpochAttestations] = &reference{refs: 1} b.state.CurrentEpochAttestations = val @@ -567,7 +567,7 @@ func (b *BeaconState) AppendHistoricalRoots(root [32]byte) error { } b.lock.RLock() roots := b.state.HistoricalRoots - if b.sharedFieldReferences[historicalRoots].refs > 1 { + if b.sharedFieldReferences[historicalRoots].Refs() > 1 { if featureconfig.Get().EnableStateRefCopy { roots = make([][]byte, len(b.state.HistoricalRoots)) copy(roots, b.state.HistoricalRoots) @@ -596,7 +596,7 @@ func (b *BeaconState) AppendCurrentEpochAttestations(val *pbp2p.PendingAttestati b.lock.RLock() atts := b.state.CurrentEpochAttestations - if b.sharedFieldReferences[currentEpochAttestations].refs > 1 { + if b.sharedFieldReferences[currentEpochAttestations].Refs() > 1 { if featureconfig.Get().EnableStateRefCopy { atts = make([]*pbp2p.PendingAttestation, len(b.state.CurrentEpochAttestations)) copy(atts, b.state.CurrentEpochAttestations) @@ -625,7 +625,7 @@ func (b *BeaconState) AppendPreviousEpochAttestations(val *pbp2p.PendingAttestat } b.lock.RLock() atts := b.state.PreviousEpochAttestations - if b.sharedFieldReferences[previousEpochAttestations].refs > 1 { + if b.sharedFieldReferences[previousEpochAttestations].Refs() > 1 { if featureconfig.Get().EnableStateRefCopy { atts = make([]*pbp2p.PendingAttestation, len(b.state.PreviousEpochAttestations)) copy(atts, b.state.PreviousEpochAttestations) @@ -655,7 +655,7 @@ func (b *BeaconState) AppendValidator(val *ethpb.Validator) error { } b.lock.RLock() vals := b.state.Validators - if b.sharedFieldReferences[validators].refs > 1 { + if b.sharedFieldReferences[validators].Refs() > 1 { vals = b.Validators() b.sharedFieldReferences[validators].MinusRef() b.sharedFieldReferences[validators] = &reference{refs: 1} @@ -686,7 +686,7 @@ func (b *BeaconState) AppendBalance(bal uint64) error { b.lock.RLock() bals := b.state.Balances - if b.sharedFieldReferences[balances].refs > 1 { + if b.sharedFieldReferences[balances].Refs() > 1 { bals = b.Balances() b.sharedFieldReferences[balances].MinusRef() b.sharedFieldReferences[balances] = &reference{refs: 1} diff --git a/beacon-chain/state/state_trie.go b/beacon-chain/state/state_trie.go index cbdcae7891..7335f37fc7 100644 --- a/beacon-chain/state/state_trie.go +++ b/beacon-chain/state/state_trie.go @@ -48,7 +48,7 @@ func InitializeFromProtoUnsafe(st *pbp2p.BeaconState) (*BeaconState, error) { b.dirtyIndices[fieldIndex(i)] = []uint64{} b.stateFieldLeaves[fieldIndex(i)] = &FieldTrie{ field: fieldIndex(i), - reference: &reference{1}, + reference: &reference{refs: 1}, Mutex: new(sync.Mutex), } } @@ -158,7 +158,7 @@ func (b *BeaconState) Copy() *BeaconState { // Finalizer runs when dst is being destroyed in garbage collection. runtime.SetFinalizer(dst, func(b *BeaconState) { for field, v := range b.sharedFieldReferences { - v.refs-- + v.MinusRef() if b.stateFieldLeaves[field].reference != nil { b.stateFieldLeaves[field].MinusRef() } diff --git a/beacon-chain/state/types.go b/beacon-chain/state/types.go index 0f657f9d5f..42671b39c4 100644 --- a/beacon-chain/state/types.go +++ b/beacon-chain/state/types.go @@ -77,6 +77,7 @@ var fieldMap map[fieldIndex]dataType // copy is performed then the state must increment the refs counter. type reference struct { refs uint + lock sync.RWMutex } // ErrNilInnerState returns when the inner state is nil and no copy set or get @@ -103,11 +104,21 @@ type ReadOnlyValidator struct { validator *ethpb.Validator } +func (r *reference) Refs() uint { + r.lock.RLock() + defer r.lock.RUnlock() + return r.refs +} + func (r *reference) AddRef() { + r.lock.Lock() r.refs++ + r.lock.Unlock() } func (r *reference) MinusRef() { + r.lock.Lock() + defer r.lock.Unlock() // Do not reduce further if object // already has 0 reference to prevent overflow. if r.refs == 0 {