mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-10 13:58:09 -05:00
Compare commits
3 Commits
process-ex
...
state-diff
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3aa7998e32 | ||
|
|
f71de7ed1f | ||
|
|
cf76ea4852 |
@@ -12,6 +12,47 @@ import (
|
||||
"github.com/OffchainLabs/prysm/v6/proto/prysm/v1alpha1/attestation"
|
||||
)
|
||||
|
||||
// ConvertToAltair converts a Phase 0 beacon state to an Altair beacon state.
|
||||
func ConvertToAltair(state state.BeaconState) (state.BeaconState, error) {
|
||||
epoch := time.CurrentEpoch(state)
|
||||
|
||||
numValidators := state.NumValidators()
|
||||
s := ðpb.BeaconStateAltair{
|
||||
GenesisTime: state.GenesisTime(),
|
||||
GenesisValidatorsRoot: state.GenesisValidatorsRoot(),
|
||||
Slot: state.Slot(),
|
||||
Fork: ðpb.Fork{
|
||||
PreviousVersion: state.Fork().CurrentVersion,
|
||||
CurrentVersion: params.BeaconConfig().AltairForkVersion,
|
||||
Epoch: epoch,
|
||||
},
|
||||
LatestBlockHeader: state.LatestBlockHeader(),
|
||||
BlockRoots: state.BlockRoots(),
|
||||
StateRoots: state.StateRoots(),
|
||||
HistoricalRoots: state.HistoricalRoots(),
|
||||
Eth1Data: state.Eth1Data(),
|
||||
Eth1DataVotes: state.Eth1DataVotes(),
|
||||
Eth1DepositIndex: state.Eth1DepositIndex(),
|
||||
Validators: state.Validators(),
|
||||
Balances: state.Balances(),
|
||||
RandaoMixes: state.RandaoMixes(),
|
||||
Slashings: state.Slashings(),
|
||||
PreviousEpochParticipation: make([]byte, numValidators),
|
||||
CurrentEpochParticipation: make([]byte, numValidators),
|
||||
JustificationBits: state.JustificationBits(),
|
||||
PreviousJustifiedCheckpoint: state.PreviousJustifiedCheckpoint(),
|
||||
CurrentJustifiedCheckpoint: state.CurrentJustifiedCheckpoint(),
|
||||
FinalizedCheckpoint: state.FinalizedCheckpoint(),
|
||||
InactivityScores: make([]uint64, numValidators),
|
||||
}
|
||||
|
||||
newState, err := state_native.InitializeFromProtoUnsafeAltair(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newState, nil
|
||||
}
|
||||
|
||||
// UpgradeToAltair updates input state to return the version Altair state.
|
||||
//
|
||||
// Spec code:
|
||||
@@ -64,39 +105,7 @@ import (
|
||||
// post.next_sync_committee = get_next_sync_committee(post)
|
||||
// return post
|
||||
func UpgradeToAltair(ctx context.Context, state state.BeaconState) (state.BeaconState, error) {
|
||||
epoch := time.CurrentEpoch(state)
|
||||
|
||||
numValidators := state.NumValidators()
|
||||
s := ðpb.BeaconStateAltair{
|
||||
GenesisTime: state.GenesisTime(),
|
||||
GenesisValidatorsRoot: state.GenesisValidatorsRoot(),
|
||||
Slot: state.Slot(),
|
||||
Fork: ðpb.Fork{
|
||||
PreviousVersion: state.Fork().CurrentVersion,
|
||||
CurrentVersion: params.BeaconConfig().AltairForkVersion,
|
||||
Epoch: epoch,
|
||||
},
|
||||
LatestBlockHeader: state.LatestBlockHeader(),
|
||||
BlockRoots: state.BlockRoots(),
|
||||
StateRoots: state.StateRoots(),
|
||||
HistoricalRoots: state.HistoricalRoots(),
|
||||
Eth1Data: state.Eth1Data(),
|
||||
Eth1DataVotes: state.Eth1DataVotes(),
|
||||
Eth1DepositIndex: state.Eth1DepositIndex(),
|
||||
Validators: state.Validators(),
|
||||
Balances: state.Balances(),
|
||||
RandaoMixes: state.RandaoMixes(),
|
||||
Slashings: state.Slashings(),
|
||||
PreviousEpochParticipation: make([]byte, numValidators),
|
||||
CurrentEpochParticipation: make([]byte, numValidators),
|
||||
JustificationBits: state.JustificationBits(),
|
||||
PreviousJustifiedCheckpoint: state.PreviousJustifiedCheckpoint(),
|
||||
CurrentJustifiedCheckpoint: state.CurrentJustifiedCheckpoint(),
|
||||
FinalizedCheckpoint: state.FinalizedCheckpoint(),
|
||||
InactivityScores: make([]uint64, numValidators),
|
||||
}
|
||||
|
||||
newState, err := state_native.InitializeFromProtoUnsafeAltair(s)
|
||||
newState, err := ConvertToAltair(state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -15,6 +15,129 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ConvertToElectra converts a Deneb beacon state to an Electra beacon state.
|
||||
func ConvertToElectra(beaconState state.BeaconState) (state.BeaconState, error) {
|
||||
currentSyncCommittee, err := beaconState.CurrentSyncCommittee()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nextSyncCommittee, err := beaconState.NextSyncCommittee()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prevEpochParticipation, err := beaconState.PreviousEpochParticipation()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
currentEpochParticipation, err := beaconState.CurrentEpochParticipation()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inactivityScores, err := beaconState.InactivityScores()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadHeader, err := beaconState.LatestExecutionPayloadHeader()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
txRoot, err := payloadHeader.TransactionsRoot()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wdRoot, err := payloadHeader.WithdrawalsRoot()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wi, err := beaconState.NextWithdrawalIndex()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vi, err := beaconState.NextWithdrawalValidatorIndex()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
summaries, err := beaconState.HistoricalSummaries()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
excessBlobGas, err := payloadHeader.ExcessBlobGas()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blobGasUsed, err := payloadHeader.BlobGasUsed()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := ðpb.BeaconStateElectra{
|
||||
GenesisTime: beaconState.GenesisTime(),
|
||||
GenesisValidatorsRoot: beaconState.GenesisValidatorsRoot(),
|
||||
Slot: beaconState.Slot(),
|
||||
Fork: ðpb.Fork{
|
||||
PreviousVersion: beaconState.Fork().CurrentVersion,
|
||||
CurrentVersion: params.BeaconConfig().ElectraForkVersion,
|
||||
Epoch: time.CurrentEpoch(beaconState),
|
||||
},
|
||||
LatestBlockHeader: beaconState.LatestBlockHeader(),
|
||||
BlockRoots: beaconState.BlockRoots(),
|
||||
StateRoots: beaconState.StateRoots(),
|
||||
HistoricalRoots: beaconState.HistoricalRoots(),
|
||||
Eth1Data: beaconState.Eth1Data(),
|
||||
Eth1DataVotes: beaconState.Eth1DataVotes(),
|
||||
Eth1DepositIndex: beaconState.Eth1DepositIndex(),
|
||||
Validators: beaconState.Validators(),
|
||||
Balances: beaconState.Balances(),
|
||||
RandaoMixes: beaconState.RandaoMixes(),
|
||||
Slashings: beaconState.Slashings(),
|
||||
PreviousEpochParticipation: prevEpochParticipation,
|
||||
CurrentEpochParticipation: currentEpochParticipation,
|
||||
JustificationBits: beaconState.JustificationBits(),
|
||||
PreviousJustifiedCheckpoint: beaconState.PreviousJustifiedCheckpoint(),
|
||||
CurrentJustifiedCheckpoint: beaconState.CurrentJustifiedCheckpoint(),
|
||||
FinalizedCheckpoint: beaconState.FinalizedCheckpoint(),
|
||||
InactivityScores: inactivityScores,
|
||||
CurrentSyncCommittee: currentSyncCommittee,
|
||||
NextSyncCommittee: nextSyncCommittee,
|
||||
LatestExecutionPayloadHeader: &enginev1.ExecutionPayloadHeaderDeneb{
|
||||
ParentHash: payloadHeader.ParentHash(),
|
||||
FeeRecipient: payloadHeader.FeeRecipient(),
|
||||
StateRoot: payloadHeader.StateRoot(),
|
||||
ReceiptsRoot: payloadHeader.ReceiptsRoot(),
|
||||
LogsBloom: payloadHeader.LogsBloom(),
|
||||
PrevRandao: payloadHeader.PrevRandao(),
|
||||
BlockNumber: payloadHeader.BlockNumber(),
|
||||
GasLimit: payloadHeader.GasLimit(),
|
||||
GasUsed: payloadHeader.GasUsed(),
|
||||
Timestamp: payloadHeader.Timestamp(),
|
||||
ExtraData: payloadHeader.ExtraData(),
|
||||
BaseFeePerGas: payloadHeader.BaseFeePerGas(),
|
||||
BlockHash: payloadHeader.BlockHash(),
|
||||
TransactionsRoot: txRoot,
|
||||
WithdrawalsRoot: wdRoot,
|
||||
ExcessBlobGas: excessBlobGas,
|
||||
BlobGasUsed: blobGasUsed,
|
||||
},
|
||||
NextWithdrawalIndex: wi,
|
||||
NextWithdrawalValidatorIndex: vi,
|
||||
HistoricalSummaries: summaries,
|
||||
|
||||
DepositRequestsStartIndex: params.BeaconConfig().UnsetDepositRequestsStartIndex,
|
||||
DepositBalanceToConsume: 0,
|
||||
EarliestConsolidationEpoch: helpers.ActivationExitEpoch(slots.ToEpoch(beaconState.Slot())),
|
||||
PendingDeposits: make([]*ethpb.PendingDeposit, 0),
|
||||
PendingPartialWithdrawals: make([]*ethpb.PendingPartialWithdrawal, 0),
|
||||
PendingConsolidations: make([]*ethpb.PendingConsolidation, 0),
|
||||
}
|
||||
|
||||
// need to cast the beaconState to use in helper functions
|
||||
post, err := state_native.InitializeFromProtoUnsafeElectra(s)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to initialize post electra beaconState")
|
||||
}
|
||||
return post, nil
|
||||
}
|
||||
|
||||
// UpgradeToElectra updates inputs a generic state to return the version Electra state.
|
||||
//
|
||||
// nolint:dupword
|
||||
|
||||
@@ -7,6 +7,7 @@ go_library(
|
||||
visibility = [
|
||||
"//beacon-chain:__subpackages__",
|
||||
"//cmd/prysmctl/testnet:__pkg__",
|
||||
"//consensus-types/hdiff:__subpackages__",
|
||||
"//testing/spectest:__subpackages__",
|
||||
"//validator/client:__pkg__",
|
||||
],
|
||||
|
||||
@@ -25,6 +25,9 @@ go_library(
|
||||
"migration_state_validators.go",
|
||||
"schema.go",
|
||||
"state.go",
|
||||
"state_diff.go",
|
||||
"state_diff_cache.go",
|
||||
"state_diff_helpers.go",
|
||||
"state_summary.go",
|
||||
"state_summary_cache.go",
|
||||
"utils.go",
|
||||
@@ -44,6 +47,7 @@ go_library(
|
||||
"//config/fieldparams:go_default_library",
|
||||
"//config/params:go_default_library",
|
||||
"//consensus-types/blocks:go_default_library",
|
||||
"//consensus-types/hdiff:go_default_library",
|
||||
"//consensus-types/interfaces:go_default_library",
|
||||
"//consensus-types/light-client:go_default_library",
|
||||
"//consensus-types/primitives:go_default_library",
|
||||
@@ -51,6 +55,7 @@ go_library(
|
||||
"//encoding/bytesutil:go_default_library",
|
||||
"//encoding/ssz/detect:go_default_library",
|
||||
"//io/file:go_default_library",
|
||||
"//math:go_default_library",
|
||||
"//monitoring/progress:go_default_library",
|
||||
"//monitoring/tracing:go_default_library",
|
||||
"//monitoring/tracing/trace:go_default_library",
|
||||
@@ -94,6 +99,7 @@ go_test(
|
||||
"migration_archived_index_test.go",
|
||||
"migration_block_slot_index_test.go",
|
||||
"migration_state_validators_test.go",
|
||||
"state_diff_test.go",
|
||||
"state_summary_test.go",
|
||||
"state_test.go",
|
||||
"utils_test.go",
|
||||
@@ -116,6 +122,7 @@ go_test(
|
||||
"//consensus-types/light-client:go_default_library",
|
||||
"//consensus-types/primitives:go_default_library",
|
||||
"//encoding/bytesutil:go_default_library",
|
||||
"//math:go_default_library",
|
||||
"//proto/dbval:go_default_library",
|
||||
"//proto/engine/v1:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
|
||||
@@ -91,6 +91,7 @@ type Store struct {
|
||||
blockCache *ristretto.Cache[string, interfaces.ReadOnlySignedBeaconBlock]
|
||||
validatorEntryCache *ristretto.Cache[[]byte, *ethpb.Validator]
|
||||
stateSummaryCache *stateSummaryCache
|
||||
stateDiffCache *stateDiffCache
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
@@ -112,6 +113,7 @@ var Buckets = [][]byte{
|
||||
lightClientUpdatesBucket,
|
||||
lightClientBootstrapBucket,
|
||||
lightClientSyncCommitteeBucket,
|
||||
stateDiffBucket,
|
||||
// Indices buckets.
|
||||
blockSlotIndicesBucket,
|
||||
stateSlotIndicesBucket,
|
||||
@@ -182,6 +184,7 @@ func NewKVStore(ctx context.Context, dirPath string, opts ...KVStoreOption) (*St
|
||||
blockCache: blockCache,
|
||||
validatorEntryCache: validatorCache,
|
||||
stateSummaryCache: newStateSummaryCache(),
|
||||
stateDiffCache: nil,
|
||||
ctx: ctx,
|
||||
}
|
||||
for _, o := range opts {
|
||||
@@ -200,6 +203,14 @@ func NewKVStore(ctx context.Context, dirPath string, opts ...KVStoreOption) (*St
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if features.Get().EnableStateDiff {
|
||||
sdCache, err := newStateDiffCache(kv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kv.stateDiffCache = sdCache
|
||||
}
|
||||
|
||||
return kv, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ var (
|
||||
stateValidatorsBucket = []byte("state-validators")
|
||||
feeRecipientBucket = []byte("fee-recipient")
|
||||
registrationBucket = []byte("registration")
|
||||
stateDiffBucket = []byte("state-diff")
|
||||
|
||||
// Light Client Updates Bucket
|
||||
lightClientUpdatesBucket = []byte("light-client-updates")
|
||||
|
||||
232
beacon-chain/db/kv/state_diff.go
Normal file
232
beacon-chain/db/kv/state_diff.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package kv
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/OffchainLabs/prysm/v6/beacon-chain/state"
|
||||
"github.com/OffchainLabs/prysm/v6/config/params"
|
||||
"github.com/OffchainLabs/prysm/v6/consensus-types/hdiff"
|
||||
"github.com/OffchainLabs/prysm/v6/consensus-types/primitives"
|
||||
"github.com/OffchainLabs/prysm/v6/monitoring/tracing/trace"
|
||||
"github.com/pkg/errors"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
const (
|
||||
stateSuffix = "_s"
|
||||
validatorSuffix = "_v"
|
||||
balancesSuffix = "_b"
|
||||
)
|
||||
|
||||
/*
|
||||
We use a level-based approach to save state diffs. The levels are 0-6, where each level corresponds to an exponent of 2 (exponents[lvl]).
|
||||
The data at level 0 is saved every 2**exponent[0] slots and always contains a full state snapshot that is used as a base for the delta saved at other levels.
|
||||
*/
|
||||
|
||||
// saveStateByDiff takes a state and decides between saving a full state snapshot or a diff.
|
||||
func (s *Store) saveStateByDiff(ctx context.Context, st state.ReadOnlyBeaconState) error {
|
||||
_, span := trace.StartSpan(ctx, "BeaconDB.saveStateByDiff")
|
||||
defer span.End()
|
||||
|
||||
if st == nil {
|
||||
return errors.New("state is nil")
|
||||
}
|
||||
|
||||
slot := st.Slot()
|
||||
offset := s.getOffset()
|
||||
if uint64(slot) < offset {
|
||||
return ErrSlotBeforeOffset
|
||||
}
|
||||
|
||||
// Find the level to save the state.
|
||||
lvl := computeLevel(offset, slot)
|
||||
if lvl == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save full state if level is 0.
|
||||
if lvl == 0 {
|
||||
return s.saveFullSnapshot(st)
|
||||
}
|
||||
|
||||
// Get anchor state to compute the diff from.
|
||||
anchorState, err := s.getAnchorState(offset, lvl, slot)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.saveHdiff(lvl, anchorState, st)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stateByDiff retrieves the full state for a given slot.
|
||||
func (s *Store) stateByDiff(ctx context.Context, slot primitives.Slot) (state.BeaconState, error) {
|
||||
offset := s.getOffset()
|
||||
if uint64(slot) < offset {
|
||||
return nil, ErrSlotBeforeOffset
|
||||
}
|
||||
|
||||
snapshot, diffChain, err := s.getBaseAndDiffChain(offset, slot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, diff := range diffChain {
|
||||
snapshot, err = hdiff.ApplyDiff(ctx, snapshot, diff)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
// SaveHdiff computes the diff between the anchor state and the current state and saves it to the database.
|
||||
func (s *Store) saveHdiff(lvl int, anchor, st state.ReadOnlyBeaconState) error {
|
||||
slot := uint64(st.Slot())
|
||||
key := makeKey(lvl, slot)
|
||||
|
||||
diff, err := hdiff.Diff(anchor, st)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.db.Update(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket(stateDiffBucket)
|
||||
if bucket == nil {
|
||||
return bolt.ErrBucketNotFound
|
||||
}
|
||||
buf := append(key, stateSuffix...)
|
||||
if err := bucket.Put(buf, diff.StateDiff); err != nil {
|
||||
return err
|
||||
}
|
||||
buf = append(key, validatorSuffix...)
|
||||
if err := bucket.Put(buf, diff.ValidatorDiffs); err != nil {
|
||||
return err
|
||||
}
|
||||
buf = append(key, balancesSuffix...)
|
||||
if err := bucket.Put(buf, diff.BalancesDiff); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save the full state to the cache (if not the last level).
|
||||
if lvl != len(params.StateHierarchyExponents())-1 {
|
||||
err = s.stateDiffCache.setAnchor(lvl, st)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveFullSnapshot saves the full level 0 state snapshot to the database.
|
||||
func (s *Store) saveFullSnapshot(st state.ReadOnlyBeaconState) error {
|
||||
slot := uint64(st.Slot())
|
||||
key := makeKey(0, slot)
|
||||
stateBytes, err := st.MarshalSSZ()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// add version key to value
|
||||
enc, err := addKey(st.Version(), stateBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.db.Update(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket(stateDiffBucket)
|
||||
if bucket == nil {
|
||||
return bolt.ErrBucketNotFound
|
||||
}
|
||||
|
||||
if err := bucket.Put(key, enc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Save the full state to the cache, and invalidate other levels.
|
||||
s.stateDiffCache.clearAnchors()
|
||||
err = s.stateDiffCache.setAnchor(0, st)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) getDiff(lvl int, slot uint64) (hdiff.HdiffBytes, error) {
|
||||
key := makeKey(lvl, slot)
|
||||
var stateDiff []byte
|
||||
var validatorDiff []byte
|
||||
var balancesDiff []byte
|
||||
|
||||
err := s.db.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket(stateDiffBucket)
|
||||
if bucket == nil {
|
||||
return bolt.ErrBucketNotFound
|
||||
}
|
||||
buf := append(key, stateSuffix...)
|
||||
stateDiff = bucket.Get(buf)
|
||||
if stateDiff == nil {
|
||||
return errors.New("state diff not found")
|
||||
}
|
||||
buf = append(key, validatorSuffix...)
|
||||
validatorDiff = bucket.Get(buf)
|
||||
if validatorDiff == nil {
|
||||
return errors.New("validator diff not found")
|
||||
}
|
||||
buf = append(key, balancesSuffix...)
|
||||
balancesDiff = bucket.Get(buf)
|
||||
if balancesDiff == nil {
|
||||
return errors.New("balances diff not found")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return hdiff.HdiffBytes{}, err
|
||||
}
|
||||
|
||||
return hdiff.HdiffBytes{
|
||||
StateDiff: stateDiff,
|
||||
ValidatorDiffs: validatorDiff,
|
||||
BalancesDiff: balancesDiff,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) getFullSnapshot(slot uint64) (state.BeaconState, error) {
|
||||
key := makeKey(0, slot)
|
||||
var enc []byte
|
||||
|
||||
err := s.db.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket(stateDiffBucket)
|
||||
if bucket == nil {
|
||||
return bolt.ErrBucketNotFound
|
||||
}
|
||||
enc = bucket.Get(key)
|
||||
if enc == nil {
|
||||
return errors.New("state not found")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.decodeStateSnapshot(enc)
|
||||
}
|
||||
77
beacon-chain/db/kv/state_diff_cache.go
Normal file
77
beacon-chain/db/kv/state_diff_cache.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package kv
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/OffchainLabs/prysm/v6/beacon-chain/state"
|
||||
"github.com/OffchainLabs/prysm/v6/config/params"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
type stateDiffCache struct {
|
||||
sync.RWMutex
|
||||
anchors []state.ReadOnlyBeaconState
|
||||
offset uint64
|
||||
}
|
||||
|
||||
func newStateDiffCache(s *Store) (*stateDiffCache, error) {
|
||||
var offset uint64
|
||||
|
||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(stateDiffBucket)
|
||||
if bucket == nil {
|
||||
return bbolt.ErrBucketNotFound
|
||||
}
|
||||
|
||||
offsetBytes := bucket.Get([]byte("offset"))
|
||||
if offsetBytes == nil {
|
||||
return errors.New("state diff cache: offset not found")
|
||||
}
|
||||
offset = binary.LittleEndian.Uint64(offsetBytes)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stateDiffCache{
|
||||
anchors: make([]state.ReadOnlyBeaconState, len(params.StateHierarchyExponents())),
|
||||
offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *stateDiffCache) getAnchor(level int) state.ReadOnlyBeaconState {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
return c.anchors[level]
|
||||
}
|
||||
|
||||
func (c *stateDiffCache) setAnchor(level int, anchor state.ReadOnlyBeaconState) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if level >= len(c.anchors) || level < 0 {
|
||||
return errors.New("state diff cache: anchor level out of range")
|
||||
}
|
||||
c.anchors[level] = anchor
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stateDiffCache) getOffset() uint64 {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
return c.offset
|
||||
}
|
||||
|
||||
func (c *stateDiffCache) setOffset(offset uint64) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.offset = offset
|
||||
}
|
||||
|
||||
func (c *stateDiffCache) clearAnchors() {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
c.anchors = make([]state.ReadOnlyBeaconState, len(params.StateHierarchyExponents()))
|
||||
}
|
||||
234
beacon-chain/db/kv/state_diff_helpers.go
Normal file
234
beacon-chain/db/kv/state_diff_helpers.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package kv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/OffchainLabs/prysm/v6/beacon-chain/state"
|
||||
state_native "github.com/OffchainLabs/prysm/v6/beacon-chain/state/state-native"
|
||||
"github.com/OffchainLabs/prysm/v6/config/params"
|
||||
"github.com/OffchainLabs/prysm/v6/consensus-types/hdiff"
|
||||
"github.com/OffchainLabs/prysm/v6/consensus-types/primitives"
|
||||
"github.com/OffchainLabs/prysm/v6/math"
|
||||
ethpb "github.com/OffchainLabs/prysm/v6/proto/prysm/v1alpha1"
|
||||
"github.com/OffchainLabs/prysm/v6/runtime/version"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
var (
|
||||
offsetKey = []byte("offset")
|
||||
ErrSlotBeforeOffset = errors.New("slot is before root offset")
|
||||
)
|
||||
|
||||
func makeKey(level int, slot uint64) []byte {
|
||||
buf := make([]byte, 16)
|
||||
buf[0] = byte(level)
|
||||
binary.LittleEndian.PutUint64(buf[1:], slot)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (s *Store) getAnchorState(offset uint64, lvl int, slot primitives.Slot) (anchor state.ReadOnlyBeaconState, err error) {
|
||||
if lvl <= 0 || lvl >= len(params.StateHierarchyExponents()) {
|
||||
return nil, errors.New("invalid value for level")
|
||||
}
|
||||
|
||||
relSlot := uint64(slot) - offset
|
||||
prevExp := params.StateHierarchyExponents()[lvl-1]
|
||||
span := math.PowerOf2(prevExp)
|
||||
anchorSlot := primitives.Slot((relSlot / span * span) + offset)
|
||||
|
||||
// anchorLvl can be [0, lvl-1]
|
||||
anchorLvl := computeLevel(offset, anchorSlot)
|
||||
if anchorLvl == -1 {
|
||||
return nil, errors.New("could not compute anchor level")
|
||||
}
|
||||
|
||||
// Check if we have the anchor in cache.
|
||||
anchor = s.stateDiffCache.getAnchor(anchorLvl)
|
||||
if anchor != nil {
|
||||
return anchor, nil
|
||||
}
|
||||
|
||||
// If not, load it from the database.
|
||||
anchor, err = s.stateByDiff(context.Background(), anchorSlot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Save it in the cache.
|
||||
err = s.stateDiffCache.setAnchor(anchorLvl, anchor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return anchor, nil
|
||||
}
|
||||
|
||||
// computeLevel computes the level in the diff tree. Returns -1 in case slot should not be in tree.
|
||||
func computeLevel(offset uint64, slot primitives.Slot) int {
|
||||
rel := uint64(slot) - offset
|
||||
for i, exp := range params.StateHierarchyExponents() {
|
||||
span := math.PowerOf2(exp)
|
||||
if rel%span == 0 {
|
||||
return i
|
||||
}
|
||||
}
|
||||
// If rel isn’t on any of the boundaries, we should ignore saving it.
|
||||
return -1
|
||||
}
|
||||
|
||||
func (s *Store) setOffset(slot primitives.Slot) error {
|
||||
err := s.db.Update(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(stateDiffBucket)
|
||||
if bucket == nil {
|
||||
return bbolt.ErrBucketNotFound
|
||||
}
|
||||
|
||||
offsetBytes := bucket.Get(offsetKey)
|
||||
if offsetBytes != nil {
|
||||
return fmt.Errorf("offset already set to %d", binary.LittleEndian.Uint64(offsetBytes))
|
||||
}
|
||||
|
||||
offsetBytes = make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(offsetBytes, uint64(slot))
|
||||
if err := bucket.Put(offsetKey, offsetBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Save the offset in the cache.
|
||||
s.stateDiffCache.setOffset(uint64(slot))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) getOffset() uint64 {
|
||||
return s.stateDiffCache.getOffset()
|
||||
}
|
||||
|
||||
func keyForSnapshot(v int) []byte {
|
||||
switch v {
|
||||
case version.Fulu:
|
||||
return fuluKey
|
||||
case version.Electra:
|
||||
return ElectraKey
|
||||
case version.Deneb:
|
||||
return denebKey
|
||||
case version.Capella:
|
||||
return capellaKey
|
||||
case version.Bellatrix:
|
||||
return bellatrixKey
|
||||
case version.Altair:
|
||||
return altairKey
|
||||
default:
|
||||
// Phase0
|
||||
return []byte{}
|
||||
}
|
||||
}
|
||||
|
||||
func addKey(v int, bytes []byte) ([]byte, error) {
|
||||
key := keyForSnapshot(v)
|
||||
enc := make([]byte, len(key)+len(bytes))
|
||||
copy(enc, key)
|
||||
copy(enc[len(key):], bytes)
|
||||
return enc, nil
|
||||
}
|
||||
|
||||
func (s *Store) decodeStateSnapshot(enc []byte) (state.BeaconState, error) {
|
||||
switch {
|
||||
case hasFuluKey(enc):
|
||||
var fuluState ethpb.BeaconStateFulu
|
||||
if err := fuluState.UnmarshalSSZ(enc[len(ElectraKey):]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return state_native.InitializeFromProtoUnsafeFulu(&fuluState)
|
||||
case HasElectraKey(enc):
|
||||
var electraState ethpb.BeaconStateElectra
|
||||
if err := electraState.UnmarshalSSZ(enc[len(ElectraKey):]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return state_native.InitializeFromProtoUnsafeElectra(&electraState)
|
||||
case hasDenebKey(enc):
|
||||
var denebState ethpb.BeaconStateDeneb
|
||||
if err := denebState.UnmarshalSSZ(enc[len(denebKey):]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return state_native.InitializeFromProtoUnsafeDeneb(&denebState)
|
||||
case hasCapellaKey(enc):
|
||||
var capellaState ethpb.BeaconStateCapella
|
||||
if err := capellaState.UnmarshalSSZ(enc[len(capellaKey):]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return state_native.InitializeFromProtoUnsafeCapella(&capellaState)
|
||||
case hasBellatrixKey(enc):
|
||||
var bellatrixState ethpb.BeaconStateBellatrix
|
||||
if err := bellatrixState.UnmarshalSSZ(enc[len(bellatrixKey):]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return state_native.InitializeFromProtoUnsafeBellatrix(&bellatrixState)
|
||||
case hasAltairKey(enc):
|
||||
var altairState ethpb.BeaconStateAltair
|
||||
if err := altairState.UnmarshalSSZ(enc[len(altairKey):]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return state_native.InitializeFromProtoUnsafeAltair(&altairState)
|
||||
default:
|
||||
var phase0State ethpb.BeaconState
|
||||
if err := phase0State.UnmarshalSSZ(enc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return state_native.InitializeFromProtoUnsafePhase0(&phase0State)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) getBaseAndDiffChain(offset uint64, slot primitives.Slot) (state.BeaconState, []hdiff.HdiffBytes, error) {
|
||||
rel := uint64(slot) - offset
|
||||
lvl := computeLevel(offset, slot)
|
||||
if lvl == -1 {
|
||||
return nil, nil, errors.New("slot not in tree")
|
||||
}
|
||||
|
||||
exponents := params.StateHierarchyExponents()
|
||||
|
||||
baseSpan := math.PowerOf2(exponents[0])
|
||||
baseAnchorSlot := (rel / baseSpan * baseSpan) + offset
|
||||
|
||||
var diffChainIndices []uint64
|
||||
for i := 1; i <= lvl; i++ {
|
||||
span := math.PowerOf2(exponents[i])
|
||||
diffSlot := rel / span * span
|
||||
if diffSlot == baseAnchorSlot {
|
||||
continue
|
||||
}
|
||||
diffChainIndices = appendUnique(diffChainIndices, diffSlot+offset)
|
||||
}
|
||||
|
||||
baseSnapshot, err := s.getFullSnapshot(baseAnchorSlot)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
diffChain := make([]hdiff.HdiffBytes, 0, len(diffChainIndices))
|
||||
for _, diffSlot := range diffChainIndices {
|
||||
diff, err := s.getDiff(computeLevel(offset, primitives.Slot(diffSlot)), diffSlot)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
diffChain = append(diffChain, diff)
|
||||
}
|
||||
|
||||
return baseSnapshot, diffChain, nil
|
||||
}
|
||||
|
||||
func appendUnique(s []uint64, v uint64) []uint64 {
|
||||
for _, x := range s {
|
||||
if x == v {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return append(s, v)
|
||||
}
|
||||
577
beacon-chain/db/kv/state_diff_test.go
Normal file
577
beacon-chain/db/kv/state_diff_test.go
Normal file
@@ -0,0 +1,577 @@
|
||||
package kv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/OffchainLabs/prysm/v6/beacon-chain/state"
|
||||
"github.com/OffchainLabs/prysm/v6/config/params"
|
||||
"github.com/OffchainLabs/prysm/v6/consensus-types/primitives"
|
||||
"github.com/OffchainLabs/prysm/v6/math"
|
||||
ethpb "github.com/OffchainLabs/prysm/v6/proto/prysm/v1alpha1"
|
||||
"github.com/OffchainLabs/prysm/v6/runtime/version"
|
||||
"github.com/OffchainLabs/prysm/v6/testing/require"
|
||||
"github.com/OffchainLabs/prysm/v6/testing/util"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
func TestStateDiff_LoadOrInitOffset(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
err := setOffsetInDB(db, 10)
|
||||
require.NoError(t, err)
|
||||
offset := db.getOffset()
|
||||
require.Equal(t, uint64(10), offset)
|
||||
|
||||
err = db.setOffset(10)
|
||||
require.ErrorContains(t, "offset already set", err)
|
||||
offset = db.getOffset()
|
||||
require.Equal(t, uint64(10), offset)
|
||||
}
|
||||
|
||||
func TestStateDiff_ComputeLevel(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
offset := db.getOffset()
|
||||
|
||||
// 2 ** 21
|
||||
lvl := computeLevel(offset, primitives.Slot(math.PowerOf2(21)))
|
||||
require.Equal(t, 0, lvl)
|
||||
|
||||
// 2 ** 21 * 3
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(21)*3))
|
||||
require.Equal(t, 0, lvl)
|
||||
|
||||
// 2 ** 18
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(18)))
|
||||
require.Equal(t, 1, lvl)
|
||||
|
||||
// 2 ** 18 * 3
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(18)*3))
|
||||
require.Equal(t, 1, lvl)
|
||||
|
||||
// 2 ** 16
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(16)))
|
||||
require.Equal(t, 2, lvl)
|
||||
|
||||
// 2 ** 16 * 3
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(16)*3))
|
||||
require.Equal(t, 2, lvl)
|
||||
|
||||
// 2 ** 13
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(13)))
|
||||
require.Equal(t, 3, lvl)
|
||||
|
||||
// 2 ** 13 * 3
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(13)*3))
|
||||
require.Equal(t, 3, lvl)
|
||||
|
||||
// 2 ** 11
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(11)))
|
||||
require.Equal(t, 4, lvl)
|
||||
|
||||
// 2 ** 11 * 3
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(11)*3))
|
||||
require.Equal(t, 4, lvl)
|
||||
|
||||
// 2 ** 9
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(9)))
|
||||
require.Equal(t, 5, lvl)
|
||||
|
||||
// 2 ** 9 * 3
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(9)*3))
|
||||
require.Equal(t, 5, lvl)
|
||||
|
||||
// 2 ** 5
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(5)))
|
||||
require.Equal(t, 6, lvl)
|
||||
|
||||
// 2 ** 5 * 3
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(5)*3))
|
||||
require.Equal(t, 6, lvl)
|
||||
|
||||
// 2 ** 7
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(7)))
|
||||
require.Equal(t, 6, lvl)
|
||||
|
||||
// 2 ** 5 + 1
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(5)+1))
|
||||
require.Equal(t, -1, lvl)
|
||||
|
||||
// 2 ** 5 + 16
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(5)+16))
|
||||
require.Equal(t, -1, lvl)
|
||||
|
||||
// 2 ** 5 + 32
|
||||
lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(5)+32))
|
||||
require.Equal(t, 6, lvl)
|
||||
|
||||
}
|
||||
|
||||
func TestStateDiff_SaveFullSnapshot(t *testing.T) {
|
||||
// test for every version
|
||||
for v := 0; v < 6; v++ {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
|
||||
// Create state with slot 0
|
||||
st, enc := createState(t, 0, v)
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.db.View(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(stateDiffBucket)
|
||||
if bucket == nil {
|
||||
return bbolt.ErrBucketNotFound
|
||||
}
|
||||
s := bucket.Get(makeKey(0, uint64(0)))
|
||||
if s == nil {
|
||||
return bbolt.ErrIncompatibleValue
|
||||
}
|
||||
require.DeepSSZEqual(t, enc, s)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateDiff_SaveAndReadFullSnapshot(t *testing.T) {
|
||||
// test for every version
|
||||
for v := 0; v < 6; v++ {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
|
||||
st, _ := createState(t, 0, v)
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
readSt, err := db.stateByDiff(context.Background(), 0)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, readSt)
|
||||
|
||||
stSSZ, err := st.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
readStSSZ, err := readSt.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
require.DeepSSZEqual(t, stSSZ, readStSSZ)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateDiff_SaveDiff(t *testing.T) {
|
||||
// test for every version
|
||||
for v := 0; v < 6; v++ {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
|
||||
// Create state with slot 2**21
|
||||
slot := primitives.Slot(math.PowerOf2(21))
|
||||
st, enc := createState(t, slot, v)
|
||||
|
||||
err := setOffsetInDB(db, uint64(slot))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.db.View(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(stateDiffBucket)
|
||||
if bucket == nil {
|
||||
return bbolt.ErrBucketNotFound
|
||||
}
|
||||
s := bucket.Get(makeKey(0, uint64(slot)))
|
||||
if s == nil {
|
||||
return bbolt.ErrIncompatibleValue
|
||||
}
|
||||
require.DeepSSZEqual(t, enc, s)
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// create state with slot 2**18 (+2**21)
|
||||
slot = primitives.Slot(math.PowerOf2(18) + math.PowerOf2(21))
|
||||
st, _ = createState(t, slot, v)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
key := makeKey(1, uint64(slot))
|
||||
err = db.db.View(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(stateDiffBucket)
|
||||
if bucket == nil {
|
||||
return bbolt.ErrBucketNotFound
|
||||
}
|
||||
buf := append(key, "_s"...)
|
||||
s := bucket.Get(buf)
|
||||
if s == nil {
|
||||
return bbolt.ErrIncompatibleValue
|
||||
}
|
||||
buf = append(key, "_v"...)
|
||||
v := bucket.Get(buf)
|
||||
if v == nil {
|
||||
return bbolt.ErrIncompatibleValue
|
||||
}
|
||||
buf = append(key, "_b"...)
|
||||
b := bucket.Get(buf)
|
||||
if b == nil {
|
||||
return bbolt.ErrIncompatibleValue
|
||||
}
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateDiff_SaveAndReadDiff(t *testing.T) {
|
||||
// test for every version
|
||||
for v := 0; v < 6; v++ {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
|
||||
st, _ := createState(t, 0, v)
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
slot := primitives.Slot(math.PowerOf2(5))
|
||||
st, _ = createState(t, slot, v)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
readSt, err := db.stateByDiff(context.Background(), slot)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, readSt)
|
||||
|
||||
stSSZ, err := st.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
readStSSZ, err := readSt.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
require.DeepSSZEqual(t, stSSZ, readStSSZ)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateDiff_SaveAndReadDiff_MultipleLevels(t *testing.T) {
|
||||
// test for every version
|
||||
for v := 0; v < 6; v++ {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
|
||||
st, _ := createState(t, 0, v)
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
slot := primitives.Slot(math.PowerOf2(11))
|
||||
st, _ = createState(t, slot, v)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
readSt, err := db.stateByDiff(context.Background(), slot)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, readSt)
|
||||
|
||||
stSSZ, err := st.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
readStSSZ, err := readSt.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
require.DeepSSZEqual(t, stSSZ, readStSSZ)
|
||||
|
||||
slot = primitives.Slot(math.PowerOf2(11) + math.PowerOf2(9))
|
||||
st, _ = createState(t, slot, v)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
readSt, err = db.stateByDiff(context.Background(), slot)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, readSt)
|
||||
|
||||
stSSZ, err = st.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
readStSSZ, err = readSt.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
require.DeepSSZEqual(t, stSSZ, readStSSZ)
|
||||
|
||||
slot = primitives.Slot(math.PowerOf2(11) + math.PowerOf2(9) + math.PowerOf2(5))
|
||||
st, _ = createState(t, slot, v)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
readSt, err = db.stateByDiff(context.Background(), slot)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, readSt)
|
||||
|
||||
stSSZ, err = st.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
readStSSZ, err = readSt.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
require.DeepSSZEqual(t, stSSZ, readStSSZ)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateDiff_SaveAndReadDiffForkTransition(t *testing.T) {
|
||||
// test for every version
|
||||
for v := 0; v < 5; v++ {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
|
||||
st, _ := createState(t, 0, v)
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
slot := primitives.Slot(math.PowerOf2(5))
|
||||
st, _ = createState(t, slot, v+1)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
readSt, err := db.stateByDiff(context.Background(), slot)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, readSt)
|
||||
|
||||
stSSZ, err := st.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
readStSSZ, err := readSt.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
require.DeepSSZEqual(t, stSSZ, readStSSZ)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateDiff_OffsetCache(t *testing.T) {
|
||||
// test for slot numbers 0 and 1 for every version
|
||||
for slotNum := 0; slotNum < 2; slotNum++ {
|
||||
// test for every version
|
||||
for v := 0; v < 6; v++ {
|
||||
t.Run(fmt.Sprintf("slotNum=%d,%s", slotNum, version.String(v)), func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
|
||||
slot := primitives.Slot(slotNum)
|
||||
err := setOffsetInDB(db, uint64(slot))
|
||||
require.NoError(t, err)
|
||||
st, _ := createState(t, slot, v)
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
offset := db.stateDiffCache.getOffset()
|
||||
require.Equal(t, uint64(slotNum), offset)
|
||||
|
||||
slot2 := primitives.Slot(uint64(slotNum) + math.PowerOf2(params.StateHierarchyExponents()[0]))
|
||||
st2, _ := createState(t, slot2, v)
|
||||
err = db.saveStateByDiff(context.Background(), st2)
|
||||
require.NoError(t, err)
|
||||
|
||||
offset = db.stateDiffCache.getOffset()
|
||||
require.Equal(t, uint64(slot), offset)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateDiff_AnchorCache(t *testing.T) {
|
||||
// test for every version
|
||||
for v := 0; v < 6; v++ {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
exponents := params.StateHierarchyExponents()
|
||||
localCache := make([]state.ReadOnlyBeaconState, len(exponents)-1)
|
||||
db := setupDB(t)
|
||||
err := setOffsetInDB(db, 0) // lvl 0
|
||||
require.NoError(t, err)
|
||||
|
||||
// at first the cache should be empty
|
||||
for i := 0; i < len(params.StateHierarchyExponents()); i++ {
|
||||
anchor := db.stateDiffCache.getAnchor(i)
|
||||
require.IsNil(t, anchor)
|
||||
}
|
||||
|
||||
// add level 0
|
||||
slot := primitives.Slot(0) // offset 0 is already set
|
||||
st, _ := createState(t, slot, v)
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
localCache[0] = st
|
||||
|
||||
// level 0 should be the same
|
||||
require.DeepEqual(t, localCache[0], db.stateDiffCache.getAnchor(0))
|
||||
|
||||
// rest of the cache should be nil
|
||||
for i := 1; i < len(exponents)-1; i++ {
|
||||
require.IsNil(t, db.stateDiffCache.getAnchor(i))
|
||||
}
|
||||
|
||||
// skip last level as it does not get cached
|
||||
for i := len(exponents) - 2; i > 0; i-- {
|
||||
slot = primitives.Slot(math.PowerOf2(exponents[i]))
|
||||
st, _ := createState(t, slot, v)
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
localCache[i] = st
|
||||
|
||||
// anchor cache must match local cache
|
||||
for i := 0; i < len(exponents)-1; i++ {
|
||||
if localCache[i] == nil {
|
||||
require.IsNil(t, db.stateDiffCache.getAnchor(i))
|
||||
continue
|
||||
}
|
||||
localSSZ, err := localCache[i].MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
anchorSSZ, err := db.stateDiffCache.getAnchor(i).MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
require.DeepSSZEqual(t, localSSZ, anchorSSZ)
|
||||
}
|
||||
}
|
||||
|
||||
// moving to a new tree should invalidate the cache except for level 0
|
||||
twoTo21 := math.PowerOf2(21)
|
||||
slot = primitives.Slot(twoTo21)
|
||||
st, _ = createState(t, slot, v)
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
localCache = make([]state.ReadOnlyBeaconState, len(exponents)-1)
|
||||
localCache[0] = st
|
||||
|
||||
// level 0 should be the same
|
||||
require.DeepEqual(t, localCache[0], db.stateDiffCache.getAnchor(0))
|
||||
|
||||
// rest of the cache should be nil
|
||||
for i := 1; i < len(exponents)-1; i++ {
|
||||
require.IsNil(t, db.stateDiffCache.getAnchor(i))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createState(t *testing.T, slot primitives.Slot, v int) (state.ReadOnlyBeaconState, []byte) {
|
||||
p := params.BeaconConfig()
|
||||
var st state.BeaconState
|
||||
var err error
|
||||
switch v {
|
||||
case version.Altair:
|
||||
st, err = util.NewBeaconStateAltair()
|
||||
require.NoError(t, err)
|
||||
err = st.SetFork(ðpb.Fork{
|
||||
PreviousVersion: p.GenesisForkVersion,
|
||||
CurrentVersion: p.AltairForkVersion,
|
||||
Epoch: p.AltairForkEpoch,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
case version.Bellatrix:
|
||||
st, err = util.NewBeaconStateBellatrix()
|
||||
require.NoError(t, err)
|
||||
err = st.SetFork(ðpb.Fork{
|
||||
PreviousVersion: p.AltairForkVersion,
|
||||
CurrentVersion: p.BellatrixForkVersion,
|
||||
Epoch: p.BellatrixForkEpoch,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
case version.Capella:
|
||||
st, err = util.NewBeaconStateCapella()
|
||||
require.NoError(t, err)
|
||||
err = st.SetFork(ðpb.Fork{
|
||||
PreviousVersion: p.BellatrixForkVersion,
|
||||
CurrentVersion: p.CapellaForkVersion,
|
||||
Epoch: p.CapellaForkEpoch,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
case version.Deneb:
|
||||
st, err = util.NewBeaconStateDeneb()
|
||||
require.NoError(t, err)
|
||||
err = st.SetFork(ðpb.Fork{
|
||||
PreviousVersion: p.CapellaForkVersion,
|
||||
CurrentVersion: p.DenebForkVersion,
|
||||
Epoch: p.DenebForkEpoch,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
case version.Electra:
|
||||
st, err = util.NewBeaconStateElectra()
|
||||
require.NoError(t, err)
|
||||
err = st.SetFork(ðpb.Fork{
|
||||
PreviousVersion: p.DenebForkVersion,
|
||||
CurrentVersion: p.ElectraForkVersion,
|
||||
Epoch: p.ElectraForkEpoch,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
default:
|
||||
st, err = util.NewBeaconState()
|
||||
require.NoError(t, err)
|
||||
err = st.SetFork(ðpb.Fork{
|
||||
PreviousVersion: p.GenesisForkVersion,
|
||||
CurrentVersion: p.GenesisForkVersion,
|
||||
Epoch: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err = st.SetSlot(slot)
|
||||
require.NoError(t, err)
|
||||
slashings := make([]uint64, 8192)
|
||||
slashings[0] = uint64(rand.Intn(10))
|
||||
err = st.SetSlashings(slashings)
|
||||
require.NoError(t, err)
|
||||
stssz, err := st.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
enc, err := addKey(v, stssz)
|
||||
require.NoError(t, err)
|
||||
return st, enc
|
||||
}
|
||||
|
||||
func setOffsetInDB(s *Store, offset uint64) error {
|
||||
err := s.db.Update(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket(stateDiffBucket)
|
||||
if bucket == nil {
|
||||
return bbolt.ErrBucketNotFound
|
||||
}
|
||||
|
||||
offsetBytes := bucket.Get(offsetKey)
|
||||
if offsetBytes != nil {
|
||||
return fmt.Errorf("offset already set to %d", binary.LittleEndian.Uint64(offsetBytes))
|
||||
}
|
||||
|
||||
offsetBytes = make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(offsetBytes, offset)
|
||||
if err := bucket.Put(offsetKey, offsetBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sdCache, err := newStateDiffCache(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.stateDiffCache = sdCache
|
||||
return nil
|
||||
}
|
||||
@@ -264,6 +264,8 @@ type WriteOnlyEth1Data interface {
|
||||
AppendEth1DataVotes(val *ethpb.Eth1Data) error
|
||||
SetEth1DepositIndex(val uint64) error
|
||||
ExitEpochAndUpdateChurn(exitBalance primitives.Gwei) (primitives.Epoch, error)
|
||||
SetExitBalanceToConsume(val primitives.Gwei) error
|
||||
SetEarliestExitEpoch(val primitives.Epoch) error
|
||||
}
|
||||
|
||||
// WriteOnlyValidators defines a struct which only has write access to validators methods.
|
||||
@@ -331,6 +333,7 @@ type WriteOnlyWithdrawals interface {
|
||||
DequeuePendingPartialWithdrawals(num uint64) error
|
||||
SetNextWithdrawalIndex(i uint64) error
|
||||
SetNextWithdrawalValidatorIndex(i primitives.ValidatorIndex) error
|
||||
SetPendingPartialWithdrawals(val []*ethpb.PendingPartialWithdrawal) error
|
||||
}
|
||||
|
||||
type WriteOnlyConsolidations interface {
|
||||
|
||||
@@ -75,3 +75,33 @@ func (b *BeaconState) ExitEpochAndUpdateChurn(exitBalance primitives.Gwei) (prim
|
||||
|
||||
return b.earliestExitEpoch, nil
|
||||
}
|
||||
|
||||
// SetExitBalanceToConsume sets the exit balance to consume. This method mutates the state.
|
||||
func (b *BeaconState) SetExitBalanceToConsume(exitBalanceToConsume primitives.Gwei) error {
|
||||
if b.version < version.Electra {
|
||||
return errNotSupported("SetExitBalanceToConsume", b.version)
|
||||
}
|
||||
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
b.exitBalanceToConsume = exitBalanceToConsume
|
||||
b.markFieldAsDirty(types.ExitBalanceToConsume)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetEarliestExitEpoch sets the earliest exit epoch. This method mutates the state.
|
||||
func (b *BeaconState) SetEarliestExitEpoch(earliestExitEpoch primitives.Epoch) error {
|
||||
if b.version < version.Electra {
|
||||
return errNotSupported("SetEarliestExitEpoch", b.version)
|
||||
}
|
||||
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
b.earliestExitEpoch = earliestExitEpoch
|
||||
b.markFieldAsDirty(types.EarliestExitEpoch)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -100,3 +100,22 @@ func (b *BeaconState) DequeuePendingPartialWithdrawals(n uint64) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetPendingPartialWithdrawals sets the pending partial withdrawals. This method mutates the state.
|
||||
func (b *BeaconState) SetPendingPartialWithdrawals(pendingPartialWithdrawals []*eth.PendingPartialWithdrawal) error {
|
||||
if b.version < version.Electra {
|
||||
return errNotSupported("SetPendingPartialWithdrawals", b.version)
|
||||
}
|
||||
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
if pendingPartialWithdrawals == nil {
|
||||
return errors.New("cannot set nil pending partial withdrawals")
|
||||
}
|
||||
|
||||
b.pendingPartialWithdrawals = pendingPartialWithdrawals
|
||||
b.markFieldAsDirty(types.PendingPartialWithdrawals)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
3
changelog/potuz_hdiff_diff_type.md
Normal file
3
changelog/potuz_hdiff_diff_type.md
Normal file
@@ -0,0 +1,3 @@
|
||||
### Added
|
||||
|
||||
- Add native state diff type and marshalling functions
|
||||
@@ -53,6 +53,7 @@ type Flags struct {
|
||||
EnableDutiesV2 bool // EnableDutiesV2 sets validator client to use the get Duties V2 endpoint
|
||||
EnableWeb bool // EnableWeb enables the webui on the validator client
|
||||
SSZOnly bool // SSZOnly forces the validator client to use SSZ for communication with the beacon node when REST mode is enabled (useful for debugging)
|
||||
EnableStateDiff bool // EnableStateDiff enables the experimental state diff feature for the beacon node.
|
||||
// Logging related toggles.
|
||||
DisableGRPCConnectionLogs bool // Disables logging when a new grpc client has connected.
|
||||
EnableFullSSZDataLogging bool // Enables logging for full ssz data on rejected gossip messages
|
||||
@@ -277,6 +278,10 @@ func ConfigureBeaconChain(ctx *cli.Context) error {
|
||||
logEnabled(enableExperimentalAttestationPool)
|
||||
cfg.EnableExperimentalAttestationPool = true
|
||||
}
|
||||
if ctx.IsSet(enableStateDiff.Name) {
|
||||
logEnabled(enableStateDiff)
|
||||
cfg.EnableStateDiff = true
|
||||
}
|
||||
if ctx.IsSet(forceHeadFlag.Name) {
|
||||
logEnabled(forceHeadFlag)
|
||||
cfg.ForceHead = ctx.String(forceHeadFlag.Name)
|
||||
|
||||
@@ -176,6 +176,10 @@ var (
|
||||
Name: "enable-experimental-attestation-pool",
|
||||
Usage: "Enables an experimental attestation pool design.",
|
||||
}
|
||||
enableStateDiff = &cli.BoolFlag{
|
||||
Name: "enable-state-diff",
|
||||
Usage: "Enables the experimental state diff feature.",
|
||||
}
|
||||
// forceHeadFlag is a flag to force the head of the beacon chain to a specific block.
|
||||
forceHeadFlag = &cli.StringFlag{
|
||||
Name: "sync-from",
|
||||
@@ -267,6 +271,7 @@ var BeaconChainFlags = combinedFlags([]cli.Flag{
|
||||
DisableQUIC,
|
||||
EnableDiscoveryReboot,
|
||||
enableExperimentalAttestationPool,
|
||||
enableStateDiff,
|
||||
forceHeadFlag,
|
||||
blacklistRoots,
|
||||
}, deprecatedBeaconFlags, deprecatedFlags, upcomingDeprecation)
|
||||
|
||||
@@ -14,6 +14,7 @@ go_library(
|
||||
"mainnet_config.go",
|
||||
"minimal_config.go",
|
||||
"network_config.go",
|
||||
"state_diff_config.go",
|
||||
"testnet_e2e_config.go",
|
||||
"testnet_holesky_config.go",
|
||||
"testnet_hoodi_config.go",
|
||||
|
||||
9
config/params/state_diff_config.go
Normal file
9
config/params/state_diff_config.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package params
|
||||
|
||||
var (
|
||||
stateHierarchyExponents = []uint64{21, 18, 16, 13, 11, 9, 5}
|
||||
)
|
||||
|
||||
func StateHierarchyExponents() []uint64 {
|
||||
return stateHierarchyExponents
|
||||
}
|
||||
@@ -40,6 +40,12 @@ func NewWrappedExecutionData(v proto.Message) (interfaces.ExecutionData, error)
|
||||
case *enginev1.ExecutionBundleElectra:
|
||||
// note: no payload changes in electra so using deneb
|
||||
return WrappedExecutionPayloadDeneb(pbStruct.Payload)
|
||||
case *enginev1.ExecutionPayloadHeader:
|
||||
return WrappedExecutionPayloadHeader(pbStruct)
|
||||
case *enginev1.ExecutionPayloadHeaderCapella:
|
||||
return WrappedExecutionPayloadHeaderCapella(pbStruct)
|
||||
case *enginev1.ExecutionPayloadHeaderDeneb:
|
||||
return WrappedExecutionPayloadHeaderDeneb(pbStruct)
|
||||
default:
|
||||
return nil, ErrUnsupportedVersion
|
||||
}
|
||||
|
||||
53
consensus-types/hdiff/BUILD.bazel
Normal file
53
consensus-types/hdiff/BUILD.bazel
Normal file
@@ -0,0 +1,53 @@
|
||||
load("@prysm//tools/go:def.bzl", "go_library", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = ["state_diff.go"],
|
||||
importpath = "github.com/OffchainLabs/prysm/v6/consensus-types/hdiff",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//beacon-chain/core/altair:go_default_library",
|
||||
"//beacon-chain/core/capella:go_default_library",
|
||||
"//beacon-chain/core/deneb:go_default_library",
|
||||
"//beacon-chain/core/electra:go_default_library",
|
||||
"//beacon-chain/core/execution:go_default_library",
|
||||
"//beacon-chain/state:go_default_library",
|
||||
"//config/fieldparams:go_default_library",
|
||||
"//consensus-types/blocks:go_default_library",
|
||||
"//consensus-types/helpers:go_default_library",
|
||||
"//consensus-types/interfaces:go_default_library",
|
||||
"//consensus-types/primitives:go_default_library",
|
||||
"//proto/engine/v1:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"//runtime/version:go_default_library",
|
||||
"@com_github_golang_snappy//:go_default_library",
|
||||
"@com_github_pkg_errors//:go_default_library",
|
||||
"@com_github_prysmaticlabs_fastssz//:go_default_library",
|
||||
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
|
||||
"@com_github_sirupsen_logrus//:go_default_library",
|
||||
"@org_golang_google_protobuf//proto:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "go_default_test",
|
||||
srcs = [
|
||||
"fuzz_test.go",
|
||||
"property_test.go",
|
||||
"security_test.go",
|
||||
"state_diff_test.go",
|
||||
],
|
||||
embed = [":go_default_library"],
|
||||
deps = [
|
||||
"//beacon-chain/core/transition:go_default_library",
|
||||
"//beacon-chain/state:go_default_library",
|
||||
"//beacon-chain/state/state-native:go_default_library",
|
||||
"//consensus-types/blocks:go_default_library",
|
||||
"//consensus-types/primitives:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"//testing/require:go_default_library",
|
||||
"//testing/util:go_default_library",
|
||||
"@com_github_golang_snappy//:go_default_library",
|
||||
"@com_github_pkg_errors//:go_default_library",
|
||||
],
|
||||
)
|
||||
491
consensus-types/hdiff/fuzz_test.go
Normal file
491
consensus-types/hdiff/fuzz_test.go
Normal file
@@ -0,0 +1,491 @@
|
||||
package hdiff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
ethpb "github.com/OffchainLabs/prysm/v6/proto/prysm/v1alpha1"
|
||||
"github.com/OffchainLabs/prysm/v6/testing/util"
|
||||
)
|
||||
|
||||
// FuzzNewHdiff tests parsing variations of realistic diffs
|
||||
func FuzzNewHdiff(f *testing.F) {
|
||||
// Add seed corpus with various valid diffs from realistic scenarios
|
||||
sizes := []uint64{8, 16, 32}
|
||||
for _, size := range sizes {
|
||||
source, _ := util.DeterministicGenesisStateElectra(f, size)
|
||||
|
||||
// Create various realistic target states
|
||||
scenarios := []string{"slot_change", "balance_change", "validator_change", "multiple_changes"}
|
||||
for _, scenario := range scenarios {
|
||||
target := source.Copy()
|
||||
|
||||
switch scenario {
|
||||
case "slot_change":
|
||||
_ = target.SetSlot(source.Slot() + 1)
|
||||
case "balance_change":
|
||||
balances := target.Balances()
|
||||
if len(balances) > 0 {
|
||||
balances[0] += 1000000000
|
||||
_ = target.SetBalances(balances)
|
||||
}
|
||||
case "validator_change":
|
||||
validators := target.Validators()
|
||||
if len(validators) > 0 {
|
||||
validators[0].EffectiveBalance += 1000000000
|
||||
_ = target.SetValidators(validators)
|
||||
}
|
||||
case "multiple_changes":
|
||||
_ = target.SetSlot(source.Slot() + 5)
|
||||
balances := target.Balances()
|
||||
validators := target.Validators()
|
||||
if len(balances) > 0 && len(validators) > 0 {
|
||||
balances[0] += 2000000000
|
||||
validators[0].EffectiveBalance += 1000000000
|
||||
_ = target.SetBalances(balances)
|
||||
_ = target.SetValidators(validators)
|
||||
}
|
||||
}
|
||||
|
||||
validDiff, err := Diff(source, target)
|
||||
if err == nil {
|
||||
f.Add(validDiff.StateDiff, validDiff.ValidatorDiffs, validDiff.BalancesDiff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
f.Fuzz(func(t *testing.T, stateDiff, validatorDiffs, balancesDiff []byte) {
|
||||
// Limit input sizes to reasonable bounds
|
||||
if len(stateDiff) > 5000 || len(validatorDiffs) > 5000 || len(balancesDiff) > 5000 {
|
||||
return
|
||||
}
|
||||
|
||||
input := HdiffBytes{
|
||||
StateDiff: stateDiff,
|
||||
ValidatorDiffs: validatorDiffs,
|
||||
BalancesDiff: balancesDiff,
|
||||
}
|
||||
|
||||
// Test parsing - should not panic even with corrupted but bounded data
|
||||
_, err := newHdiff(input)
|
||||
_ = err // Expected to fail with corrupted data
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzNewStateDiff tests the newStateDiff function with random compressed input
|
||||
func FuzzNewStateDiff(f *testing.F) {
|
||||
// Add seed corpus
|
||||
source, _ := util.DeterministicGenesisStateElectra(f, 16)
|
||||
target := source.Copy()
|
||||
_ = target.SetSlot(source.Slot() + 5)
|
||||
|
||||
diff, err := diffToState(source, target)
|
||||
if err == nil {
|
||||
serialized := diff.serialize()
|
||||
f.Add(serialized)
|
||||
}
|
||||
|
||||
// Add edge cases
|
||||
f.Add([]byte{})
|
||||
f.Add([]byte{0x01})
|
||||
f.Add([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07})
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("newStateDiff panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Should never panic, only return error
|
||||
_, err := newStateDiff(data)
|
||||
_ = err
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzNewValidatorDiffs tests validator diff deserialization
|
||||
func FuzzNewValidatorDiffs(f *testing.F) {
|
||||
// Add seed corpus
|
||||
source, _ := util.DeterministicGenesisStateElectra(f, 8)
|
||||
target := source.Copy()
|
||||
vals := target.Validators()
|
||||
if len(vals) > 0 {
|
||||
modifiedVal := ðpb.Validator{
|
||||
PublicKey: vals[0].PublicKey,
|
||||
WithdrawalCredentials: vals[0].WithdrawalCredentials,
|
||||
EffectiveBalance: vals[0].EffectiveBalance + 1000,
|
||||
Slashed: !vals[0].Slashed,
|
||||
ActivationEligibilityEpoch: vals[0].ActivationEligibilityEpoch,
|
||||
ActivationEpoch: vals[0].ActivationEpoch,
|
||||
ExitEpoch: vals[0].ExitEpoch,
|
||||
WithdrawableEpoch: vals[0].WithdrawableEpoch,
|
||||
}
|
||||
vals[0] = modifiedVal
|
||||
_ = target.SetValidators(vals)
|
||||
|
||||
// Create a simple diff for fuzzing - we'll just use raw bytes
|
||||
_, err := diffToVals(source, target)
|
||||
if err == nil {
|
||||
// Add some realistic validator diff bytes for the corpus
|
||||
f.Add([]byte{1, 0, 0, 0, 0, 0, 0, 0}) // Simple validator diff
|
||||
}
|
||||
}
|
||||
|
||||
// Add edge cases
|
||||
f.Add([]byte{})
|
||||
f.Add([]byte{0x01, 0x02, 0x03, 0x04})
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("newValidatorDiffs panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err := newValidatorDiffs(data)
|
||||
_ = err
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzNewBalancesDiff tests balance diff deserialization
|
||||
func FuzzNewBalancesDiff(f *testing.F) {
|
||||
// Add seed corpus
|
||||
source, _ := util.DeterministicGenesisStateElectra(f, 8)
|
||||
target := source.Copy()
|
||||
balances := target.Balances()
|
||||
if len(balances) > 0 {
|
||||
balances[0] += 1000
|
||||
_ = target.SetBalances(balances)
|
||||
|
||||
// Create a simple diff for fuzzing - we'll just use raw bytes
|
||||
_, err := diffToBalances(source, target)
|
||||
if err == nil {
|
||||
// Add some realistic balance diff bytes for the corpus
|
||||
f.Add([]byte{1, 0, 0, 0, 0, 0, 0, 0}) // Simple balance diff
|
||||
}
|
||||
}
|
||||
|
||||
// Add edge cases
|
||||
f.Add([]byte{})
|
||||
f.Add([]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08})
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("newBalancesDiff panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err := newBalancesDiff(data)
|
||||
_ = err
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzApplyDiff tests applying variations of valid diffs
|
||||
func FuzzApplyDiff(f *testing.F) {
|
||||
// Test with realistic state variations, not random data
|
||||
ctx := context.Background()
|
||||
|
||||
// Add seed corpus with various valid scenarios
|
||||
sizes := []uint64{8, 16, 32, 64}
|
||||
for _, size := range sizes {
|
||||
source, _ := util.DeterministicGenesisStateElectra(f, size)
|
||||
target := source.Copy()
|
||||
|
||||
// Different types of realistic changes
|
||||
scenarios := []func(){
|
||||
func() { _ = target.SetSlot(source.Slot() + 1) }, // Slot change
|
||||
func() { // Balance change
|
||||
balances := target.Balances()
|
||||
if len(balances) > 0 {
|
||||
balances[0] += 1000000000 // 1 ETH
|
||||
_ = target.SetBalances(balances)
|
||||
}
|
||||
},
|
||||
func() { // Validator change
|
||||
validators := target.Validators()
|
||||
if len(validators) > 0 {
|
||||
validators[0].EffectiveBalance += 1000000000
|
||||
_ = target.SetValidators(validators)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for _, scenario := range scenarios {
|
||||
testTarget := source.Copy()
|
||||
scenario()
|
||||
|
||||
validDiff, err := Diff(source, testTarget)
|
||||
if err == nil {
|
||||
f.Add(validDiff.StateDiff, validDiff.ValidatorDiffs, validDiff.BalancesDiff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
f.Fuzz(func(t *testing.T, stateDiff, validatorDiffs, balancesDiff []byte) {
|
||||
// Only test with reasonable sized inputs
|
||||
if len(stateDiff) > 10000 || len(validatorDiffs) > 10000 || len(balancesDiff) > 10000 {
|
||||
return
|
||||
}
|
||||
|
||||
// Create fresh source state for each test
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 8)
|
||||
|
||||
diff := HdiffBytes{
|
||||
StateDiff: stateDiff,
|
||||
ValidatorDiffs: validatorDiffs,
|
||||
BalancesDiff: balancesDiff,
|
||||
}
|
||||
|
||||
// Apply diff - errors are expected for fuzzed data
|
||||
_, err := ApplyDiff(ctx, source, diff)
|
||||
_ = err // Expected to fail with invalid data
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzReadPendingAttestation tests the pending attestation deserialization
|
||||
func FuzzReadPendingAttestation(f *testing.F) {
|
||||
// Add edge cases - this function is particularly vulnerable
|
||||
f.Add([]byte{})
|
||||
f.Add([]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}) // 8 bytes
|
||||
f.Add(make([]byte, 200)) // Larger than expected
|
||||
|
||||
// Add a case with large reported length
|
||||
largeLength := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(largeLength, 0xFFFFFFFF) // Large bits length
|
||||
f.Add(largeLength)
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("readPendingAttestation panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Make a copy since the function modifies the slice
|
||||
dataCopy := make([]byte, len(data))
|
||||
copy(dataCopy, data)
|
||||
|
||||
_, err := readPendingAttestation(&dataCopy)
|
||||
_ = err
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzKmpIndex tests the KMP algorithm implementation
|
||||
func FuzzKmpIndex(f *testing.F) {
|
||||
// Test with integer pointers to match the actual usage
|
||||
f.Add(0, "1,2,3", "1,2,3,4,5")
|
||||
f.Add(3, "1,2,3", "1,2,3,1,2,3")
|
||||
f.Add(0, "", "1,2,3")
|
||||
|
||||
f.Fuzz(func(t *testing.T, lens int, patternStr string, textStr string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("kmpIndex panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Parse comma-separated strings into int slices
|
||||
var pattern, text []int
|
||||
if patternStr != "" {
|
||||
for _, s := range strings.Split(patternStr, ",") {
|
||||
if val, err := strconv.Atoi(strings.TrimSpace(s)); err == nil {
|
||||
pattern = append(pattern, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
if textStr != "" {
|
||||
for _, s := range strings.Split(textStr, ",") {
|
||||
if val, err := strconv.Atoi(strings.TrimSpace(s)); err == nil {
|
||||
text = append(text, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to pointer slices as used in actual code
|
||||
patternPtrs := make([]*int, len(pattern))
|
||||
for i := range pattern {
|
||||
val := pattern[i]
|
||||
patternPtrs[i] = &val
|
||||
}
|
||||
|
||||
textPtrs := make([]*int, len(text))
|
||||
for i := range text {
|
||||
val := text[i]
|
||||
textPtrs[i] = &val
|
||||
}
|
||||
|
||||
integerEquals := func(a, b *int) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return *a == *b
|
||||
}
|
||||
|
||||
// Clamp lens to reasonable range to avoid infinite loops
|
||||
if lens < 0 {
|
||||
lens = 0
|
||||
}
|
||||
if lens > len(textPtrs) {
|
||||
lens = len(textPtrs)
|
||||
}
|
||||
|
||||
result := kmpIndex(lens, textPtrs, integerEquals)
|
||||
|
||||
// Basic sanity check
|
||||
if result < 0 || result > lens {
|
||||
t.Errorf("kmpIndex returned invalid result: %d for lens=%d", result, lens)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzComputeLPS tests the LPS computation for KMP
|
||||
func FuzzComputeLPS(f *testing.F) {
|
||||
// Add seed cases
|
||||
f.Add("1,2,1")
|
||||
f.Add("1,1,1")
|
||||
f.Add("1,2,3,4")
|
||||
f.Add("")
|
||||
|
||||
f.Fuzz(func(t *testing.T, patternStr string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("computeLPS panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Parse comma-separated string into int slice
|
||||
var pattern []int
|
||||
if patternStr != "" {
|
||||
for _, s := range strings.Split(patternStr, ",") {
|
||||
if val, err := strconv.Atoi(strings.TrimSpace(s)); err == nil {
|
||||
pattern = append(pattern, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to pointer slice
|
||||
patternPtrs := make([]*int, len(pattern))
|
||||
for i := range pattern {
|
||||
val := pattern[i]
|
||||
patternPtrs[i] = &val
|
||||
}
|
||||
|
||||
integerEquals := func(a, b *int) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return *a == *b
|
||||
}
|
||||
|
||||
result := computeLPS(patternPtrs, integerEquals)
|
||||
|
||||
// Verify result length matches input
|
||||
if len(result) != len(pattern) {
|
||||
t.Errorf("computeLPS returned wrong length: got %d, expected %d", len(result), len(pattern))
|
||||
}
|
||||
|
||||
// Verify all LPS values are non-negative and within bounds
|
||||
for i, lps := range result {
|
||||
if lps < 0 || lps > i {
|
||||
t.Errorf("Invalid LPS value at index %d: %d", i, lps)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzDiffToBalances tests balance diff computation
|
||||
func FuzzDiffToBalances(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, sourceData, targetData []byte) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("diffToBalances panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Convert byte data to balance arrays
|
||||
var sourceBalances, targetBalances []uint64
|
||||
|
||||
// Parse source balances (8 bytes per uint64)
|
||||
for i := 0; i+7 < len(sourceData) && len(sourceBalances) < 100; i += 8 {
|
||||
balance := binary.LittleEndian.Uint64(sourceData[i : i+8])
|
||||
sourceBalances = append(sourceBalances, balance)
|
||||
}
|
||||
|
||||
// Parse target balances
|
||||
for i := 0; i+7 < len(targetData) && len(targetBalances) < 100; i += 8 {
|
||||
balance := binary.LittleEndian.Uint64(targetData[i : i+8])
|
||||
targetBalances = append(targetBalances, balance)
|
||||
}
|
||||
|
||||
// Create states with the provided balances
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 1)
|
||||
target, _ := util.DeterministicGenesisStateElectra(t, 1)
|
||||
|
||||
if len(sourceBalances) > 0 {
|
||||
_ = source.SetBalances(sourceBalances)
|
||||
}
|
||||
if len(targetBalances) > 0 {
|
||||
_ = target.SetBalances(targetBalances)
|
||||
}
|
||||
|
||||
result, err := diffToBalances(source, target)
|
||||
|
||||
// If no error, verify result consistency
|
||||
if err == nil && len(result) > 0 {
|
||||
// Result length should match target length
|
||||
if len(result) != len(target.Balances()) {
|
||||
t.Errorf("diffToBalances result length mismatch: got %d, expected %d",
|
||||
len(result), len(target.Balances()))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// FuzzValidatorsEqual tests validator comparison
|
||||
func FuzzValidatorsEqual(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("validatorsEqual panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Create two validators and fuzz their fields
|
||||
if len(data) < 16 {
|
||||
return
|
||||
}
|
||||
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 2)
|
||||
validators := source.Validators()
|
||||
if len(validators) < 2 {
|
||||
return
|
||||
}
|
||||
|
||||
val1 := validators[0]
|
||||
val2 := validators[1]
|
||||
|
||||
// Modify validator fields based on fuzz data
|
||||
if len(data) > 0 && data[0]%2 == 0 {
|
||||
val2.EffectiveBalance = val1.EffectiveBalance + uint64(data[0])
|
||||
}
|
||||
if len(data) > 1 && data[1]%2 == 0 {
|
||||
val2.Slashed = !val1.Slashed
|
||||
}
|
||||
|
||||
// Create ReadOnlyValidator wrappers if needed
|
||||
// Since validatorsEqual expects ReadOnlyValidator interface,
|
||||
// we'll skip this test for now as it requires state wrapper implementation
|
||||
_ = val1
|
||||
_ = val2
|
||||
})
|
||||
}
|
||||
391
consensus-types/hdiff/property_test.go
Normal file
391
consensus-types/hdiff/property_test.go
Normal file
@@ -0,0 +1,391 @@
|
||||
package hdiff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/OffchainLabs/prysm/v6/consensus-types/primitives"
|
||||
"github.com/OffchainLabs/prysm/v6/testing/require"
|
||||
"github.com/OffchainLabs/prysm/v6/testing/util"
|
||||
)
|
||||
|
||||
// PropertyTestRoundTrip verifies that diff->apply is idempotent with realistic data
|
||||
func FuzzPropertyRoundTrip(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, slotDelta uint64, balanceData []byte, validatorData []byte) {
|
||||
// Limit to realistic ranges
|
||||
if slotDelta > 32 { // Max one epoch
|
||||
slotDelta = slotDelta % 32
|
||||
}
|
||||
|
||||
// Convert byte data to realistic deltas and changes
|
||||
var balanceDeltas []int64
|
||||
var validatorChanges []bool
|
||||
|
||||
// Parse balance deltas - limit to realistic amounts (8 bytes per int64)
|
||||
for i := 0; i+7 < len(balanceData) && len(balanceDeltas) < 20; i += 8 {
|
||||
delta := int64(binary.LittleEndian.Uint64(balanceData[i : i+8]))
|
||||
// Keep deltas realistic (max 10 ETH change)
|
||||
if delta > 10000000000 {
|
||||
delta = delta % 10000000000
|
||||
}
|
||||
if delta < -10000000000 {
|
||||
delta = -((-delta) % 10000000000)
|
||||
}
|
||||
balanceDeltas = append(balanceDeltas, delta)
|
||||
}
|
||||
|
||||
// Parse validator changes (1 byte per bool) - limit to small number
|
||||
for i := 0; i < len(validatorData) && len(validatorChanges) < 10; i++ {
|
||||
validatorChanges = append(validatorChanges, validatorData[i]%2 == 0)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create source state with reasonable size
|
||||
validatorCount := uint64(len(validatorChanges) + 8) // Minimum 8 validators
|
||||
if validatorCount > 64 {
|
||||
validatorCount = 64 // Cap at 64 for performance
|
||||
}
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, validatorCount)
|
||||
|
||||
// Create target state with modifications
|
||||
target := source.Copy()
|
||||
|
||||
// Apply slot change
|
||||
_ = target.SetSlot(source.Slot() + primitives.Slot(slotDelta))
|
||||
|
||||
// Apply realistic balance changes
|
||||
if len(balanceDeltas) > 0 {
|
||||
balances := target.Balances()
|
||||
for i, delta := range balanceDeltas {
|
||||
if i >= len(balances) {
|
||||
break
|
||||
}
|
||||
// Apply realistic balance changes with safe bounds
|
||||
if delta < 0 {
|
||||
if uint64(-delta) > balances[i] {
|
||||
balances[i] = 0 // Can't go below 0
|
||||
} else {
|
||||
balances[i] -= uint64(-delta)
|
||||
}
|
||||
} else {
|
||||
// Cap at reasonable maximum (1000 ETH)
|
||||
maxBalance := uint64(1000000000000) // 1000 ETH in Gwei
|
||||
if balances[i]+uint64(delta) > maxBalance {
|
||||
balances[i] = maxBalance
|
||||
} else {
|
||||
balances[i] += uint64(delta)
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = target.SetBalances(balances)
|
||||
}
|
||||
|
||||
// Apply realistic validator changes
|
||||
if len(validatorChanges) > 0 {
|
||||
validators := target.Validators()
|
||||
for i, shouldChange := range validatorChanges {
|
||||
if i >= len(validators) {
|
||||
break
|
||||
}
|
||||
if shouldChange {
|
||||
// Make realistic changes - small effective balance adjustments
|
||||
validators[i].EffectiveBalance += 1000000000 // 1 ETH
|
||||
}
|
||||
}
|
||||
_ = target.SetValidators(validators)
|
||||
}
|
||||
|
||||
// Create diff
|
||||
diff, err := Diff(source, target)
|
||||
if err != nil {
|
||||
// If diff creation fails, that's acceptable for malformed inputs
|
||||
return
|
||||
}
|
||||
|
||||
// Apply diff
|
||||
result, err := ApplyDiff(ctx, source, diff)
|
||||
if err != nil {
|
||||
// If diff application fails, that's acceptable
|
||||
return
|
||||
}
|
||||
|
||||
// Verify round-trip property: source + diff = target
|
||||
require.Equal(t, target.Slot(), result.Slot())
|
||||
|
||||
// Verify balance consistency
|
||||
targetBalances := target.Balances()
|
||||
resultBalances := result.Balances()
|
||||
require.Equal(t, len(targetBalances), len(resultBalances))
|
||||
for i := range targetBalances {
|
||||
require.Equal(t, targetBalances[i], resultBalances[i], "Balance mismatch at index %d", i)
|
||||
}
|
||||
|
||||
// Verify validator consistency
|
||||
targetVals := target.Validators()
|
||||
resultVals := result.Validators()
|
||||
require.Equal(t, len(targetVals), len(resultVals))
|
||||
for i := range targetVals {
|
||||
require.Equal(t, targetVals[i].Slashed, resultVals[i].Slashed, "Validator slashing mismatch at index %d", i)
|
||||
require.Equal(t, targetVals[i].EffectiveBalance, resultVals[i].EffectiveBalance, "Validator balance mismatch at index %d", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// PropertyTestReasonablePerformance verifies operations complete quickly with realistic data
|
||||
func FuzzPropertyResourceBounds(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, validatorCount uint8, slotDelta uint8, changeCount uint8) {
|
||||
// Use realistic parameters
|
||||
validators := uint64(validatorCount%64 + 8) // 8-71 validators
|
||||
slots := uint64(slotDelta % 32) // 0-31 slots
|
||||
changes := int(changeCount % 10) // 0-9 changes
|
||||
|
||||
// Create realistic states
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, validators)
|
||||
target := source.Copy()
|
||||
|
||||
// Apply realistic changes
|
||||
_ = target.SetSlot(source.Slot() + primitives.Slot(slots))
|
||||
|
||||
if changes > 0 {
|
||||
validatorList := target.Validators()
|
||||
for i := 0; i < changes && i < len(validatorList); i++ {
|
||||
validatorList[i].EffectiveBalance += 1000000000 // 1 ETH
|
||||
}
|
||||
_ = target.SetValidators(validatorList)
|
||||
}
|
||||
|
||||
// Operations should complete quickly
|
||||
start := time.Now()
|
||||
diff, err := Diff(source, target)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
// Should be fast
|
||||
require.Equal(t, true, duration < time.Second, "Diff creation too slow: %v", duration)
|
||||
|
||||
// Apply should also be fast
|
||||
start = time.Now()
|
||||
_, err = ApplyDiff(context.Background(), source, diff)
|
||||
duration = time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
require.Equal(t, true, duration < time.Second, "Diff application too slow: %v", duration)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// PropertyTestDiffSize verifies that diffs are smaller than full states for typical cases
|
||||
func FuzzPropertyDiffEfficiency(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, slotDelta uint64, numChanges uint8) {
|
||||
if slotDelta > 100 {
|
||||
slotDelta = slotDelta % 100
|
||||
}
|
||||
if numChanges > 10 {
|
||||
numChanges = numChanges % 10
|
||||
}
|
||||
|
||||
// Create states with small differences
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 64)
|
||||
target := source.Copy()
|
||||
|
||||
_ = target.SetSlot(source.Slot() + primitives.Slot(slotDelta))
|
||||
|
||||
// Make a few small changes
|
||||
if numChanges > 0 {
|
||||
validators := target.Validators()
|
||||
for i := uint8(0); i < numChanges && int(i) < len(validators); i++ {
|
||||
validators[i].EffectiveBalance += 1000
|
||||
}
|
||||
_ = target.SetValidators(validators)
|
||||
}
|
||||
|
||||
// Create diff
|
||||
diff, err := Diff(source, target)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// For small changes, diff should be much smaller than full state
|
||||
sourceSSZ, err := source.MarshalSSZ()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
diffSize := len(diff.StateDiff) + len(diff.ValidatorDiffs) + len(diff.BalancesDiff)
|
||||
|
||||
// Diff should be smaller than full state for small changes
|
||||
if numChanges <= 5 && slotDelta <= 10 {
|
||||
require.Equal(t, true, diffSize < len(sourceSSZ)/2,
|
||||
"Diff size %d should be less than half of state size %d", diffSize, len(sourceSSZ))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// PropertyTestBalanceConservation verifies that balance operations don't create/destroy value unexpectedly
|
||||
func FuzzPropertyBalanceConservation(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, balanceData []byte) {
|
||||
// Convert byte data to balance changes
|
||||
var balanceChanges []int64
|
||||
for i := 0; i+7 < len(balanceData) && len(balanceChanges) < 50; i += 8 {
|
||||
change := int64(binary.LittleEndian.Uint64(balanceData[i : i+8]))
|
||||
balanceChanges = append(balanceChanges, change)
|
||||
}
|
||||
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, uint64(len(balanceChanges)+10))
|
||||
originalBalances := source.Balances()
|
||||
|
||||
// Calculate total before
|
||||
var totalBefore uint64
|
||||
for _, balance := range originalBalances {
|
||||
totalBefore += balance
|
||||
}
|
||||
|
||||
// Apply balance changes via diff system
|
||||
target := source.Copy()
|
||||
targetBalances := target.Balances()
|
||||
|
||||
var totalDelta int64
|
||||
for i, delta := range balanceChanges {
|
||||
if i >= len(targetBalances) {
|
||||
break
|
||||
}
|
||||
|
||||
// Prevent underflow
|
||||
if delta < 0 && uint64(-delta) > targetBalances[i] {
|
||||
totalDelta += int64(targetBalances[i]) // Lost amount
|
||||
targetBalances[i] = 0
|
||||
} else if delta < 0 {
|
||||
targetBalances[i] -= uint64(-delta)
|
||||
totalDelta += delta
|
||||
} else {
|
||||
// Prevent overflow
|
||||
if uint64(delta) > math.MaxUint64-targetBalances[i] {
|
||||
gained := math.MaxUint64 - targetBalances[i]
|
||||
totalDelta += int64(gained)
|
||||
targetBalances[i] = math.MaxUint64
|
||||
} else {
|
||||
targetBalances[i] += uint64(delta)
|
||||
totalDelta += delta
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = target.SetBalances(targetBalances)
|
||||
|
||||
// Apply through diff system
|
||||
diff, err := Diff(source, target)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := ApplyDiff(context.Background(), source, diff)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate total after
|
||||
resultBalances := result.Balances()
|
||||
var totalAfter uint64
|
||||
for _, balance := range resultBalances {
|
||||
totalAfter += balance
|
||||
}
|
||||
|
||||
// Verify conservation (accounting for intended changes)
|
||||
expectedTotal := totalBefore
|
||||
if totalDelta >= 0 {
|
||||
expectedTotal += uint64(totalDelta)
|
||||
} else {
|
||||
if uint64(-totalDelta) <= expectedTotal {
|
||||
expectedTotal -= uint64(-totalDelta)
|
||||
} else {
|
||||
expectedTotal = 0
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, expectedTotal, totalAfter,
|
||||
"Balance conservation violated: before=%d, delta=%d, expected=%d, actual=%d",
|
||||
totalBefore, totalDelta, expectedTotal, totalAfter)
|
||||
})
|
||||
}
|
||||
|
||||
// PropertyTestMonotonicSlot verifies slot only increases
|
||||
func FuzzPropertyMonotonicSlot(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, slotDelta uint64) {
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 16)
|
||||
target := source.Copy()
|
||||
|
||||
targetSlot := source.Slot() + primitives.Slot(slotDelta)
|
||||
_ = target.SetSlot(targetSlot)
|
||||
|
||||
diff, err := Diff(source, target)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := ApplyDiff(context.Background(), source, diff)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Slot should never decrease
|
||||
require.Equal(t, true, result.Slot() >= source.Slot(),
|
||||
"Slot decreased from %d to %d", source.Slot(), result.Slot())
|
||||
|
||||
// Slot should match target
|
||||
require.Equal(t, targetSlot, result.Slot())
|
||||
})
|
||||
}
|
||||
|
||||
// PropertyTestValidatorIndexIntegrity verifies validator indices remain consistent
|
||||
func FuzzPropertyValidatorIndices(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, changeData []byte) {
|
||||
// Convert byte data to boolean changes
|
||||
var changes []bool
|
||||
for i := 0; i < len(changeData) && len(changes) < 20; i++ {
|
||||
changes = append(changes, changeData[i]%2 == 0)
|
||||
}
|
||||
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, uint64(len(changes)+5))
|
||||
target := source.Copy()
|
||||
|
||||
// Apply changes
|
||||
validators := target.Validators()
|
||||
for i, shouldChange := range changes {
|
||||
if i >= len(validators) {
|
||||
break
|
||||
}
|
||||
if shouldChange {
|
||||
validators[i].EffectiveBalance += 1000
|
||||
}
|
||||
}
|
||||
_ = target.SetValidators(validators)
|
||||
|
||||
diff, err := Diff(source, target)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := ApplyDiff(context.Background(), source, diff)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Validator count should not decrease
|
||||
require.Equal(t, true, len(result.Validators()) >= len(source.Validators()),
|
||||
"Validator count decreased from %d to %d", len(source.Validators()), len(result.Validators()))
|
||||
|
||||
// Public keys should be preserved for existing validators
|
||||
sourceVals := source.Validators()
|
||||
resultVals := result.Validators()
|
||||
for i := range sourceVals {
|
||||
if i < len(resultVals) {
|
||||
require.Equal(t, sourceVals[i].PublicKey, resultVals[i].PublicKey,
|
||||
"Public key changed at validator index %d", i)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
393
consensus-types/hdiff/security_test.go
Normal file
393
consensus-types/hdiff/security_test.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package hdiff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/OffchainLabs/prysm/v6/testing/require"
|
||||
"github.com/OffchainLabs/prysm/v6/testing/util"
|
||||
)
|
||||
|
||||
// TestIntegerOverflowProtection tests protection against balance overflow attacks
|
||||
func TestIntegerOverflowProtection(t *testing.T) {
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 8)
|
||||
|
||||
// Test balance overflow in diffToBalances - use realistic values
|
||||
t.Run("balance_diff_overflow", func(t *testing.T) {
|
||||
target := source.Copy()
|
||||
balances := target.Balances()
|
||||
|
||||
// Set high but realistic balance values (32 ETH in Gwei = 32e9)
|
||||
balances[0] = 32000000000 // 32 ETH
|
||||
balances[1] = 64000000000 // 64 ETH
|
||||
_ = target.SetBalances(balances)
|
||||
|
||||
// This should work fine with realistic values
|
||||
diffs, err := diffToBalances(source, target)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the diffs are reasonable
|
||||
require.Equal(t, true, len(diffs) > 0, "Should have balance diffs")
|
||||
})
|
||||
|
||||
// Test reasonable balance changes
|
||||
t.Run("realistic_balance_changes", func(t *testing.T) {
|
||||
// Create realistic balance changes (slashing, rewards)
|
||||
balancesDiff := []int64{1000000000, -500000000, 2000000000} // 1 ETH gain, 0.5 ETH loss, 2 ETH gain
|
||||
|
||||
// Apply to state with normal balances
|
||||
testSource := source.Copy()
|
||||
normalBalances := []uint64{32000000000, 32000000000, 32000000000} // 32 ETH each
|
||||
_ = testSource.SetBalances(normalBalances)
|
||||
|
||||
// This should work fine
|
||||
result, err := applyBalancesDiff(testSource, balancesDiff)
|
||||
require.NoError(t, err)
|
||||
|
||||
resultBalances := result.Balances()
|
||||
require.Equal(t, uint64(33000000000), resultBalances[0]) // 33 ETH
|
||||
require.Equal(t, uint64(31500000000), resultBalances[1]) // 31.5 ETH
|
||||
require.Equal(t, uint64(34000000000), resultBalances[2]) // 34 ETH
|
||||
})
|
||||
}
|
||||
|
||||
// TestReasonablePerformance tests that operations complete in reasonable time
|
||||
func TestReasonablePerformance(t *testing.T) {
|
||||
t.Run("large_state_performance", func(t *testing.T) {
|
||||
// Test with a large but realistic validator set
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 1000) // 1000 validators
|
||||
target := source.Copy()
|
||||
|
||||
// Make realistic changes
|
||||
_ = target.SetSlot(source.Slot() + 32) // One epoch
|
||||
validators := target.Validators()
|
||||
for i := 0; i < 100; i++ { // 10% of validators changed
|
||||
validators[i].EffectiveBalance += 1000000000 // 1 ETH change
|
||||
}
|
||||
_ = target.SetValidators(validators)
|
||||
|
||||
// Should complete quickly
|
||||
start := time.Now()
|
||||
diff, err := Diff(source, target)
|
||||
duration := time.Since(start)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, duration < time.Second, "Diff creation took too long: %v", duration)
|
||||
require.Equal(t, true, len(diff.StateDiff) > 0, "Should have state diff")
|
||||
})
|
||||
|
||||
t.Run("realistic_diff_application", func(t *testing.T) {
|
||||
// Test applying diffs to large states
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 500)
|
||||
target := source.Copy()
|
||||
_ = target.SetSlot(source.Slot() + 1)
|
||||
|
||||
// Create and apply diff
|
||||
diff, err := Diff(source, target)
|
||||
require.NoError(t, err)
|
||||
|
||||
start := time.Now()
|
||||
result, err := ApplyDiff(context.Background(), source, diff)
|
||||
duration := time.Since(start)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, target.Slot(), result.Slot())
|
||||
require.Equal(t, true, duration < time.Second, "Diff application took too long: %v", duration)
|
||||
})
|
||||
}
|
||||
|
||||
// TestStateTransitionValidation tests realistic state transition scenarios
|
||||
func TestStateTransitionValidation(t *testing.T) {
|
||||
t.Run("validator_slashing_scenario", func(t *testing.T) {
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 10)
|
||||
target := source.Copy()
|
||||
|
||||
// Simulate validator slashing (realistic scenario)
|
||||
validators := target.Validators()
|
||||
validators[0].Slashed = true
|
||||
validators[0].EffectiveBalance = 0 // Slashed validator loses balance
|
||||
_ = target.SetValidators(validators)
|
||||
|
||||
// This should work fine
|
||||
diff, err := Diff(source, target)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := ApplyDiff(context.Background(), source, diff)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, result.Validators()[0].Slashed)
|
||||
require.Equal(t, uint64(0), result.Validators()[0].EffectiveBalance)
|
||||
})
|
||||
|
||||
t.Run("epoch_transition_scenario", func(t *testing.T) {
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 64)
|
||||
target := source.Copy()
|
||||
|
||||
// Simulate epoch transition with multiple changes
|
||||
_ = target.SetSlot(source.Slot() + 32) // One epoch
|
||||
|
||||
// Some validators get rewards, others get penalties
|
||||
balances := target.Balances()
|
||||
for i := 0; i < len(balances); i++ {
|
||||
if i%2 == 0 {
|
||||
balances[i] += 100000000 // 0.1 ETH reward
|
||||
} else {
|
||||
if balances[i] > 50000000 {
|
||||
balances[i] -= 50000000 // 0.05 ETH penalty
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = target.SetBalances(balances)
|
||||
|
||||
// This should work smoothly
|
||||
diff, err := Diff(source, target)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := ApplyDiff(context.Background(), source, diff)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, target.Slot(), result.Slot())
|
||||
})
|
||||
|
||||
t.Run("consistent_state_root", func(t *testing.T) {
|
||||
// Test that diffs preserve state consistency
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 32)
|
||||
target := source.Copy()
|
||||
|
||||
// Make minimal changes
|
||||
_ = target.SetSlot(source.Slot() + 1)
|
||||
|
||||
// Diff and apply should be consistent
|
||||
diff, err := Diff(source, target)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := ApplyDiff(context.Background(), source, diff)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Result should match target
|
||||
require.Equal(t, target.Slot(), result.Slot())
|
||||
require.Equal(t, len(target.Validators()), len(result.Validators()))
|
||||
require.Equal(t, len(target.Balances()), len(result.Balances()))
|
||||
})
|
||||
}
|
||||
|
||||
// TestSerializationRoundTrip tests serialization consistency
|
||||
func TestSerializationRoundTrip(t *testing.T) {
|
||||
t.Run("diff_serialization_consistency", func(t *testing.T) {
|
||||
// Test that serialization and deserialization are consistent
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 16)
|
||||
target := source.Copy()
|
||||
|
||||
// Make changes
|
||||
_ = target.SetSlot(source.Slot() + 5)
|
||||
validators := target.Validators()
|
||||
validators[0].EffectiveBalance += 1000000000
|
||||
_ = target.SetValidators(validators)
|
||||
|
||||
// Create diff
|
||||
diff1, err := Diff(source, target)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Deserialize and re-serialize
|
||||
hdiff, err := newHdiff(diff1)
|
||||
require.NoError(t, err)
|
||||
|
||||
diff2 := hdiff.serialize()
|
||||
|
||||
// Apply both diffs - should get same result
|
||||
result1, err := ApplyDiff(context.Background(), source, diff1)
|
||||
require.NoError(t, err)
|
||||
|
||||
result2, err := ApplyDiff(context.Background(), source, diff2)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, result1.Slot(), result2.Slot())
|
||||
require.Equal(t, result1.Validators()[0].EffectiveBalance, result2.Validators()[0].EffectiveBalance)
|
||||
})
|
||||
|
||||
t.Run("empty_diff_handling", func(t *testing.T) {
|
||||
// Test that empty diffs are handled correctly
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 8)
|
||||
target := source.Copy() // No changes
|
||||
|
||||
// Should create minimal diff
|
||||
diff, err := Diff(source, target)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Apply should work and return equivalent state
|
||||
result, err := ApplyDiff(context.Background(), source, diff)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, source.Slot(), result.Slot())
|
||||
require.Equal(t, len(source.Validators()), len(result.Validators()))
|
||||
})
|
||||
|
||||
t.Run("compression_efficiency", func(t *testing.T) {
|
||||
// Test that compression is working effectively
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 100)
|
||||
target := source.Copy()
|
||||
|
||||
// Make small changes
|
||||
_ = target.SetSlot(source.Slot() + 1)
|
||||
validators := target.Validators()
|
||||
validators[0].EffectiveBalance += 1000000000
|
||||
_ = target.SetValidators(validators)
|
||||
|
||||
// Create diff
|
||||
diff, err := Diff(source, target)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get full state size
|
||||
fullStateSSZ, err := target.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Diff should be much smaller than full state
|
||||
diffSize := len(diff.StateDiff) + len(diff.ValidatorDiffs) + len(diff.BalancesDiff)
|
||||
require.Equal(t, true, diffSize < len(fullStateSSZ)/2,
|
||||
"Diff should be smaller than full state: diff=%d, full=%d", diffSize, len(fullStateSSZ))
|
||||
})
|
||||
}
|
||||
|
||||
// TestKMPSecurity tests the KMP algorithm for security issues
|
||||
func TestKMPSecurity(t *testing.T) {
|
||||
t.Run("nil_pointer_handling", func(t *testing.T) {
|
||||
// Test with nil pointers in the pattern/text
|
||||
pattern := []*int{nil, nil, nil}
|
||||
text := []*int{nil, nil, nil, nil, nil}
|
||||
|
||||
equals := func(a, b *int) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return *a == *b
|
||||
}
|
||||
|
||||
// Should not panic - result can be any integer
|
||||
result := kmpIndex(len(pattern), text, equals)
|
||||
_ = result // Any result is valid, just ensure no panic
|
||||
})
|
||||
|
||||
t.Run("empty_pattern_edge_case", func(t *testing.T) {
|
||||
var pattern []*int
|
||||
text := []*int{new(int), new(int)}
|
||||
|
||||
equals := func(a, b *int) bool { return a == b }
|
||||
|
||||
result := kmpIndex(0, text, equals)
|
||||
require.Equal(t, 0, result, "Empty pattern should return 0")
|
||||
_ = pattern // Silence unused variable warning
|
||||
})
|
||||
|
||||
t.Run("realistic_pattern_performance", func(t *testing.T) {
|
||||
// Test with realistic sizes to ensure good performance
|
||||
realisticSize := 100 // More realistic for validator arrays
|
||||
pattern := make([]*int, realisticSize)
|
||||
text := make([]*int, realisticSize*2)
|
||||
|
||||
// Create realistic pattern
|
||||
for i := range pattern {
|
||||
val := i % 10 // More variation
|
||||
pattern[i] = &val
|
||||
}
|
||||
for i := range text {
|
||||
val := i % 10
|
||||
text[i] = &val
|
||||
}
|
||||
|
||||
equals := func(a, b *int) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return *a == *b
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result := kmpIndex(len(pattern), text, equals)
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should complete quickly with realistic inputs
|
||||
require.Equal(t, true, duration < time.Second,
|
||||
"KMP took too long: %v", duration)
|
||||
_ = result // Any result is valid, just ensure performance is good
|
||||
})
|
||||
}
|
||||
|
||||
// TestConcurrencySafety tests thread safety of the hdiff operations
|
||||
func TestConcurrencySafety(t *testing.T) {
|
||||
t.Run("concurrent_diff_creation", func(t *testing.T) {
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 32)
|
||||
target := source.Copy()
|
||||
_ = target.SetSlot(source.Slot() + 1)
|
||||
|
||||
const numGoroutines = 10
|
||||
const iterations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, numGoroutines*iterations)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < iterations; j++ {
|
||||
_, err := Diff(source, target)
|
||||
if err != nil {
|
||||
errors <- fmt.Errorf("worker %d iteration %d: %v", workerID, j, err)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for any errors
|
||||
for err := range errors {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("concurrent_diff_application", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
source, _ := util.DeterministicGenesisStateElectra(t, 16)
|
||||
target := source.Copy()
|
||||
_ = target.SetSlot(source.Slot() + 5)
|
||||
|
||||
diff, err := Diff(source, target)
|
||||
require.NoError(t, err)
|
||||
|
||||
const numGoroutines = 10
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Each goroutine needs its own copy of the source state
|
||||
localSource := source.Copy()
|
||||
_, err := ApplyDiff(ctx, localSource, diff)
|
||||
if err != nil {
|
||||
errors <- fmt.Errorf("worker %d: %v", workerID, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for any errors
|
||||
for err := range errors {
|
||||
t.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
2040
consensus-types/hdiff/state_diff.go
Normal file
2040
consensus-types/hdiff/state_diff.go
Normal file
File diff suppressed because it is too large
Load Diff
1195
consensus-types/hdiff/state_diff_test.go
Normal file
1195
consensus-types/hdiff/state_diff_test.go
Normal file
File diff suppressed because it is too large
Load Diff
9
consensus-types/helpers/BUILD.bazel
Normal file
9
consensus-types/helpers/BUILD.bazel
Normal file
@@ -0,0 +1,9 @@
|
||||
load("@prysm//tools/go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = ["comparisons.go"],
|
||||
importpath = "github.com/OffchainLabs/prysm/v6/consensus-types/helpers",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//proto/prysm/v1alpha1:go_default_library"],
|
||||
)
|
||||
109
consensus-types/helpers/comparisons.go
Normal file
109
consensus-types/helpers/comparisons.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
ethpb "github.com/OffchainLabs/prysm/v6/proto/prysm/v1alpha1"
|
||||
)
|
||||
|
||||
func ForksEqual(s, t *ethpb.Fork) bool {
|
||||
if s == nil && t == nil {
|
||||
return true
|
||||
}
|
||||
if s == nil || t == nil {
|
||||
return false
|
||||
}
|
||||
if s.Epoch != t.Epoch {
|
||||
return false
|
||||
}
|
||||
if !bytes.Equal(s.PreviousVersion, t.PreviousVersion) {
|
||||
return false
|
||||
}
|
||||
return bytes.Equal(s.CurrentVersion, t.CurrentVersion)
|
||||
}
|
||||
|
||||
func BlockHeadersEqual(s, t *ethpb.BeaconBlockHeader) bool {
|
||||
if s == nil && t == nil {
|
||||
return true
|
||||
}
|
||||
if s == nil || t == nil {
|
||||
return false
|
||||
}
|
||||
if s.Slot != t.Slot {
|
||||
return false
|
||||
}
|
||||
if s.ProposerIndex != t.ProposerIndex {
|
||||
return false
|
||||
}
|
||||
if !bytes.Equal(s.ParentRoot, t.ParentRoot) {
|
||||
return false
|
||||
}
|
||||
if !bytes.Equal(s.StateRoot, t.StateRoot) {
|
||||
return false
|
||||
}
|
||||
return bytes.Equal(s.BodyRoot, t.BodyRoot)
|
||||
}
|
||||
|
||||
func Eth1DataEqual(s, t *ethpb.Eth1Data) bool {
|
||||
if s == nil && t == nil {
|
||||
return true
|
||||
}
|
||||
if s == nil || t == nil {
|
||||
return false
|
||||
}
|
||||
if !bytes.Equal(s.DepositRoot, t.DepositRoot) {
|
||||
return false
|
||||
}
|
||||
if s.DepositCount != t.DepositCount {
|
||||
return false
|
||||
}
|
||||
return bytes.Equal(s.BlockHash, t.BlockHash)
|
||||
}
|
||||
|
||||
func PendingDepositsEqual(s, t *ethpb.PendingDeposit) bool {
|
||||
if s == nil && t == nil {
|
||||
return true
|
||||
}
|
||||
if s == nil || t == nil {
|
||||
return false
|
||||
}
|
||||
if !bytes.Equal(s.PublicKey, t.PublicKey) {
|
||||
return false
|
||||
}
|
||||
if !bytes.Equal(s.WithdrawalCredentials, t.WithdrawalCredentials) {
|
||||
return false
|
||||
}
|
||||
if s.Amount != t.Amount {
|
||||
return false
|
||||
}
|
||||
if !bytes.Equal(s.Signature, t.Signature) {
|
||||
return false
|
||||
}
|
||||
return s.Slot == t.Slot
|
||||
}
|
||||
|
||||
func PendingPartialWithdrawalsEqual(s, t *ethpb.PendingPartialWithdrawal) bool {
|
||||
if s == nil && t == nil {
|
||||
return true
|
||||
}
|
||||
if s == nil || t == nil {
|
||||
return false
|
||||
}
|
||||
if s.Index != t.Index {
|
||||
return false
|
||||
}
|
||||
if s.Amount != t.Amount {
|
||||
return false
|
||||
}
|
||||
return s.WithdrawableEpoch == t.WithdrawableEpoch
|
||||
}
|
||||
|
||||
func PendingConsolidationsEqual(s, t *ethpb.PendingConsolidation) bool {
|
||||
if s == nil && t == nil {
|
||||
return true
|
||||
}
|
||||
if s == nil || t == nil {
|
||||
return false
|
||||
}
|
||||
return s.SourceIndex == t.SourceIndex && s.TargetIndex == t.TargetIndex
|
||||
}
|
||||
Reference in New Issue
Block a user