Fix electra state to safe share references on pending fields when append (#14895)

* Fix electra state to safe share references on pending fields when append

* Feedback
This commit is contained in:
terence
2025-02-07 19:04:02 -08:00
committed by GitHub
parent 6b3f1de19d
commit 81a2a17c5f
10 changed files with 84 additions and 27 deletions

View File

@@ -71,9 +71,7 @@ func (b *BeaconState) ExitEpochAndUpdateChurn(exitBalance primitives.Gwei) (prim
b.earliestExitEpoch = earliestExitEpoch
b.markFieldAsDirty(types.ExitBalanceToConsume)
b.rebuildTrie[types.ExitBalanceToConsume] = true
b.markFieldAsDirty(types.EarliestExitEpoch)
b.rebuildTrie[types.EarliestExitEpoch] = true
return b.earliestExitEpoch, nil
}

View File

@@ -23,13 +23,17 @@ func (b *BeaconState) AppendPendingConsolidation(val *ethpb.PendingConsolidation
b.lock.Lock()
defer b.lock.Unlock()
b.sharedFieldReferences[types.PendingConsolidations].MinusRef()
b.sharedFieldReferences[types.PendingConsolidations] = stateutil.NewRef(1)
b.pendingConsolidations = append(b.pendingConsolidations, val)
pendingConsolidations := b.pendingConsolidations
if b.sharedFieldReferences[types.PendingConsolidations].Refs() > 1 {
pendingConsolidations = make([]*ethpb.PendingConsolidation, 0, len(b.pendingConsolidations)+1)
pendingConsolidations = append(pendingConsolidations, b.pendingConsolidations...)
b.sharedFieldReferences[types.PendingConsolidations].MinusRef()
b.sharedFieldReferences[types.PendingConsolidations] = stateutil.NewRef(1)
}
b.pendingConsolidations = append(pendingConsolidations, val)
b.markFieldAsDirty(types.PendingConsolidations)
b.rebuildTrie[types.PendingConsolidations] = true
return nil
}
@@ -66,7 +70,6 @@ func (b *BeaconState) SetEarliestConsolidationEpoch(epoch primitives.Epoch) erro
b.earliestConsolidationEpoch = epoch
b.markFieldAsDirty(types.EarliestConsolidationEpoch)
b.rebuildTrie[types.EarliestConsolidationEpoch] = true
return nil
}
@@ -83,6 +86,5 @@ func (b *BeaconState) SetConsolidationBalanceToConsume(balance primitives.Gwei)
b.consolidationBalanceToConsume = balance
b.markFieldAsDirty(types.ConsolidationBalanceToConsume)
b.rebuildTrie[types.ConsolidationBalanceToConsume] = true
return nil
}

View File

@@ -20,6 +20,21 @@ func TestAppendPendingConsolidation(t *testing.T) {
require.NoError(t, err)
require.Equal(t, uint64(1), num)
pc := make([]*eth.PendingConsolidation, 0, 4)
require.NoError(t, s.SetPendingConsolidations(pc))
require.NoError(t, s.AppendPendingConsolidation(&eth.PendingConsolidation{SourceIndex: 1}))
s2 := s.Copy()
require.NoError(t, s2.AppendPendingConsolidation(&eth.PendingConsolidation{SourceIndex: 3}))
require.NoError(t, s.AppendPendingConsolidation(&eth.PendingConsolidation{SourceIndex: 2}))
pc, err = s.PendingConsolidations()
require.NoError(t, err)
require.Equal(t, primitives.ValidatorIndex(1), pc[0].SourceIndex)
require.Equal(t, primitives.ValidatorIndex(2), pc[1].SourceIndex)
pc, err = s2.PendingConsolidations()
require.NoError(t, err)
require.Equal(t, primitives.ValidatorIndex(1), pc[0].SourceIndex)
require.Equal(t, primitives.ValidatorIndex(3), pc[1].SourceIndex)
// Fails for versions older than electra
s, err = state_native.InitializeFromProtoDeneb(&eth.BeaconStateDeneb{})
require.NoError(t, err)

View File

@@ -16,6 +16,5 @@ func (b *BeaconState) SetDepositRequestsStartIndex(index uint64) error {
b.depositRequestsStartIndex = index
b.markFieldAsDirty(types.DepositRequestsStartIndex)
b.rebuildTrie[types.DepositRequestsStartIndex] = true
return nil
}

View File

@@ -23,13 +23,17 @@ func (b *BeaconState) AppendPendingDeposit(pd *ethpb.PendingDeposit) error {
b.lock.Lock()
defer b.lock.Unlock()
b.sharedFieldReferences[types.PendingDeposits].MinusRef()
b.sharedFieldReferences[types.PendingDeposits] = stateutil.NewRef(1)
b.pendingDeposits = append(b.pendingDeposits, pd)
pendingDeposits := b.pendingDeposits
if b.sharedFieldReferences[types.PendingDeposits].Refs() > 1 {
pendingDeposits = make([]*ethpb.PendingDeposit, 0, len(b.pendingDeposits)+1)
pendingDeposits = append(pendingDeposits, b.pendingDeposits...)
b.sharedFieldReferences[types.PendingDeposits].MinusRef()
b.sharedFieldReferences[types.PendingDeposits] = stateutil.NewRef(1)
}
b.pendingDeposits = append(pendingDeposits, pd)
b.markFieldAsDirty(types.PendingDeposits)
b.rebuildTrie[types.PendingDeposits] = true
return nil
}
@@ -66,6 +70,5 @@ func (b *BeaconState) SetDepositBalanceToConsume(dbtc primitives.Gwei) error {
b.depositBalanceToConsume = dbtc
b.markFieldAsDirty(types.DepositBalanceToConsume)
b.rebuildTrie[types.DepositBalanceToConsume] = true
return nil
}

View File

@@ -34,6 +34,21 @@ func TestAppendPendingDeposit(t *testing.T) {
require.Equal(t, primitives.Slot(1), pbd[0].Slot)
require.DeepEqual(t, sig, pbd[0].Signature)
ds := make([]*eth.PendingDeposit, 0, 4)
require.NoError(t, s.SetPendingDeposits(ds))
require.NoError(t, s.AppendPendingDeposit(&eth.PendingDeposit{Amount: 1}))
s2 := s.Copy()
require.NoError(t, s2.AppendPendingDeposit(&eth.PendingDeposit{Amount: 3}))
require.NoError(t, s.AppendPendingDeposit(&eth.PendingDeposit{Amount: 2}))
d, err := s.PendingDeposits()
require.NoError(t, err)
require.Equal(t, uint64(1), d[0].Amount)
require.Equal(t, uint64(2), d[1].Amount)
d, err = s2.PendingDeposits()
require.NoError(t, err)
require.Equal(t, uint64(1), d[0].Amount)
require.Equal(t, uint64(3), d[1].Amount)
// Fails for versions older than electra
s, err = state_native.InitializeFromProtoDeneb(&eth.BeaconStateDeneb{})
require.NoError(t, err)

View File

@@ -54,13 +54,17 @@ func (b *BeaconState) AppendPendingPartialWithdrawal(ppw *eth.PendingPartialWith
b.lock.Lock()
defer b.lock.Unlock()
b.sharedFieldReferences[types.PendingPartialWithdrawals].MinusRef()
b.sharedFieldReferences[types.PendingPartialWithdrawals] = stateutil.NewRef(1)
b.pendingPartialWithdrawals = append(b.pendingPartialWithdrawals, ppw)
pendingPartialWithdrawals := b.pendingPartialWithdrawals
if b.sharedFieldReferences[types.PendingPartialWithdrawals].Refs() > 1 {
pendingPartialWithdrawals = make([]*eth.PendingPartialWithdrawal, 0, len(b.pendingPartialWithdrawals)+1)
pendingPartialWithdrawals = append(pendingPartialWithdrawals, b.pendingPartialWithdrawals...)
b.sharedFieldReferences[types.PendingPartialWithdrawals].MinusRef()
b.sharedFieldReferences[types.PendingPartialWithdrawals] = stateutil.NewRef(1)
}
b.pendingPartialWithdrawals = append(pendingPartialWithdrawals, ppw)
b.markFieldAsDirty(types.PendingPartialWithdrawals)
b.rebuildTrie[types.PendingPartialWithdrawals] = true
return nil
}
@@ -81,8 +85,13 @@ func (b *BeaconState) DequeuePendingPartialWithdrawals(n uint64) error {
b.lock.Lock()
defer b.lock.Unlock()
b.sharedFieldReferences[types.PendingPartialWithdrawals].MinusRef()
b.sharedFieldReferences[types.PendingPartialWithdrawals] = stateutil.NewRef(1)
if b.sharedFieldReferences[types.PendingPartialWithdrawals].Refs() > 1 {
pendingPartialWithdrawals := make([]*eth.PendingPartialWithdrawal, len(b.pendingPartialWithdrawals))
copy(pendingPartialWithdrawals, b.pendingPartialWithdrawals)
b.pendingPartialWithdrawals = pendingPartialWithdrawals
b.sharedFieldReferences[types.PendingPartialWithdrawals].MinusRef()
b.sharedFieldReferences[types.PendingPartialWithdrawals] = stateutil.NewRef(1)
}
b.pendingPartialWithdrawals = b.pendingPartialWithdrawals[n:]

View File

@@ -68,15 +68,16 @@ func TestDequeuePendingWithdrawals(t *testing.T) {
num, err := s.NumPendingPartialWithdrawals()
require.NoError(t, err)
require.Equal(t, uint64(3), num)
s2 := s.Copy()
require.NoError(t, s.DequeuePendingPartialWithdrawals(2))
num, err = s.NumPendingPartialWithdrawals()
require.NoError(t, err)
require.Equal(t, uint64(1), num)
num, err = s2.NumPendingPartialWithdrawals()
require.NoError(t, err)
require.Equal(t, uint64(3), num)
// 2 of 1 exceeds the limit and an error should be returned
num, err = s.NumPendingPartialWithdrawals()
require.NoError(t, err)
require.Equal(t, uint64(1), num)
require.ErrorContains(t, "cannot dequeue more withdrawals than are in the queue", s.DequeuePendingPartialWithdrawals(2))
// Removing all pending partial withdrawals should be OK.
@@ -111,6 +112,19 @@ func TestAppendPendingWithdrawals(t *testing.T) {
require.NoError(t, err)
require.Equal(t, uint64(4), num)
require.NoError(t, s.AppendPendingPartialWithdrawal(&eth.PendingPartialWithdrawal{Index: 1}))
s2 := s.Copy()
require.NoError(t, s2.AppendPendingPartialWithdrawal(&eth.PendingPartialWithdrawal{Index: 3}))
require.NoError(t, s.AppendPendingPartialWithdrawal(&eth.PendingPartialWithdrawal{Index: 2}))
w, err := s.PendingPartialWithdrawals()
require.NoError(t, err)
require.Equal(t, primitives.ValidatorIndex(1), w[4].Index)
require.Equal(t, primitives.ValidatorIndex(2), w[5].Index)
w, err = s2.PendingPartialWithdrawals()
require.NoError(t, err)
require.Equal(t, primitives.ValidatorIndex(1), w[4].Index)
require.Equal(t, primitives.ValidatorIndex(3), w[5].Index)
require.ErrorContains(t, "cannot append nil pending partial withdrawal", s.AppendPendingPartialWithdrawal(nil))
s, err = InitializeFromProtoDeneb(&eth.BeaconStateDeneb{})

View File

@@ -820,7 +820,7 @@ func InitializeFromProtoUnsafeElectra(st *ethpb.BeaconStateElectra) (state.Beaco
b.sharedFieldReferences[types.Slashings] = stateutil.NewRef(1)
b.sharedFieldReferences[types.PreviousEpochParticipationBits] = stateutil.NewRef(1)
b.sharedFieldReferences[types.CurrentEpochParticipationBits] = stateutil.NewRef(1)
b.sharedFieldReferences[types.LatestExecutionPayloadHeaderDeneb] = stateutil.NewRef(1) // New in Electra.
b.sharedFieldReferences[types.LatestExecutionPayloadHeaderDeneb] = stateutil.NewRef(1)
b.sharedFieldReferences[types.HistoricalSummaries] = stateutil.NewRef(1)
b.sharedFieldReferences[types.PendingDeposits] = stateutil.NewRef(1) // New in Electra.
b.sharedFieldReferences[types.PendingPartialWithdrawals] = stateutil.NewRef(1) // New in Electra.