diff --git a/beacon-chain/state/v3/BUILD.bazel b/beacon-chain/state/v3/BUILD.bazel index 32320e1809..94784f1858 100644 --- a/beacon-chain/state/v3/BUILD.bazel +++ b/beacon-chain/state/v3/BUILD.bazel @@ -42,6 +42,7 @@ go_library( "//beacon-chain/state/v1:go_default_library", "//config/features:go_default_library", "//config/params:go_default_library", + "//container/slice:go_default_library", "//crypto/hash:go_default_library", "//encoding/bytesutil:go_default_library", "//encoding/ssz:go_default_library", @@ -53,6 +54,7 @@ go_library( "@com_github_prometheus_client_golang//prometheus/promauto:go_default_library", "@com_github_prysmaticlabs_eth2_types//:go_default_library", "@com_github_prysmaticlabs_go_bitfield//:go_default_library", + "@io_opencensus_go//trace:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", ], ) @@ -66,13 +68,21 @@ go_test( "getters_block_test.go", "getters_test.go", "getters_validator_test.go", + "setters_test.go", + "state_trie_test.go", ], embed = [":go_default_library"], deps = [ + "//beacon-chain/state/stateutil:go_default_library", + "//beacon-chain/state/types:go_default_library", "//beacon-chain/state/v1:go_default_library", + "//config/features:go_default_library", + "//config/params:go_default_library", "//encoding/bytesutil:go_default_library", "//proto/prysm/v1alpha1:go_default_library", "//testing/assert:go_default_library", "//testing/require:go_default_library", + "@com_github_prysmaticlabs_eth2_types//:go_default_library", + "@com_github_prysmaticlabs_go_bitfield//:go_default_library", ], ) diff --git a/beacon-chain/state/v3/field_roots.go b/beacon-chain/state/v3/field_roots.go index 827dcb3f2a..5accf3ba3d 100644 --- a/beacon-chain/state/v3/field_roots.go +++ b/beacon-chain/state/v3/field_roots.go @@ -1,6 +1,7 @@ package v3 import ( + "context" "encoding/binary" "sync" @@ -13,6 +14,7 @@ import ( "github.com/prysmaticlabs/prysm/encoding/bytesutil" "github.com/prysmaticlabs/prysm/encoding/ssz" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "go.opencensus.io/trace" ) var ( @@ -47,15 +49,17 @@ type stateRootHasher struct { // computeFieldRoots returns the hash tree root computations of every field in // the beacon state as a list of 32 byte roots. -//nolint:deadcode -func computeFieldRoots(state *ethpb.BeaconStateMerge) ([][]byte, error) { +func computeFieldRoots(ctx context.Context, state *ethpb.BeaconStateMerge) ([][]byte, error) { if features.Get().EnableSSZCache { - return cachedHasher.computeFieldRootsWithHasher(state) + return cachedHasher.computeFieldRootsWithHasher(ctx, state) } - return nocachedHasher.computeFieldRootsWithHasher(state) + return nocachedHasher.computeFieldRootsWithHasher(ctx, state) } -func (h *stateRootHasher) computeFieldRootsWithHasher(state *ethpb.BeaconStateMerge) ([][]byte, error) { +func (h *stateRootHasher) computeFieldRootsWithHasher(ctx context.Context, state *ethpb.BeaconStateMerge) ([][]byte, error) { + ctx, span := trace.StartSpan(ctx, "beaconState.computeFieldRootsWithHasher") + defer span.End() + if state == nil { return nil, errors.New("nil state") } @@ -219,8 +223,11 @@ func (h *stateRootHasher) computeFieldRootsWithHasher(state *ethpb.BeaconStateMe fieldRoots[23] = nextSyncCommitteeRoot[:] // Execution payload root. - //TODO: Blocked by https://github.com/ferranbt/fastssz/pull/65 - fieldRoots[24] = []byte{} + executionPayloadRoot, err := state.LatestExecutionPayloadHeader.HashTreeRoot() + if err != nil { + return nil, err + } + fieldRoots[24] = executionPayloadRoot[:] return fieldRoots, nil } diff --git a/beacon-chain/state/v3/getters_state.go b/beacon-chain/state/v3/getters_state.go index 9bd446c0b9..71604daae4 100644 --- a/beacon-chain/state/v3/getters_state.go +++ b/beacon-chain/state/v3/getters_state.go @@ -23,31 +23,32 @@ func (b *BeaconState) CloneInnerState() interface{} { b.lock.RLock() defer b.lock.RUnlock() - return ðpb.BeaconStateAltair{ - GenesisTime: b.genesisTime(), - GenesisValidatorsRoot: b.genesisValidatorRoot(), - Slot: b.slot(), - Fork: b.fork(), - LatestBlockHeader: b.latestBlockHeader(), - BlockRoots: b.blockRoots(), - StateRoots: b.stateRoots(), - HistoricalRoots: b.historicalRoots(), - Eth1Data: b.eth1Data(), - Eth1DataVotes: b.eth1DataVotes(), - Eth1DepositIndex: b.eth1DepositIndex(), - Validators: b.validators(), - Balances: b.balances(), - RandaoMixes: b.randaoMixes(), - Slashings: b.slashings(), - CurrentEpochParticipation: b.currentEpochParticipation(), - PreviousEpochParticipation: b.previousEpochParticipation(), - JustificationBits: b.justificationBits(), - PreviousJustifiedCheckpoint: b.previousJustifiedCheckpoint(), - CurrentJustifiedCheckpoint: b.currentJustifiedCheckpoint(), - FinalizedCheckpoint: b.finalizedCheckpoint(), - InactivityScores: b.inactivityScores(), - CurrentSyncCommittee: b.currentSyncCommittee(), - NextSyncCommittee: b.nextSyncCommittee(), + return ðpb.BeaconStateMerge{ + GenesisTime: b.genesisTime(), + GenesisValidatorsRoot: b.genesisValidatorRoot(), + Slot: b.slot(), + Fork: b.fork(), + LatestBlockHeader: b.latestBlockHeader(), + BlockRoots: b.blockRoots(), + StateRoots: b.stateRoots(), + HistoricalRoots: b.historicalRoots(), + Eth1Data: b.eth1Data(), + Eth1DataVotes: b.eth1DataVotes(), + Eth1DepositIndex: b.eth1DepositIndex(), + Validators: b.validators(), + Balances: b.balances(), + RandaoMixes: b.randaoMixes(), + Slashings: b.slashings(), + CurrentEpochParticipation: b.currentEpochParticipation(), + PreviousEpochParticipation: b.previousEpochParticipation(), + JustificationBits: b.justificationBits(), + PreviousJustifiedCheckpoint: b.previousJustifiedCheckpoint(), + CurrentJustifiedCheckpoint: b.currentJustifiedCheckpoint(), + FinalizedCheckpoint: b.finalizedCheckpoint(), + InactivityScores: b.inactivityScores(), + CurrentSyncCommittee: b.currentSyncCommittee(), + NextSyncCommittee: b.nextSyncCommittee(), + LatestExecutionPayloadHeader: b.latestExecutionPayloadHeader(), } } @@ -112,16 +113,15 @@ func (b *BeaconState) MarshalSSZ() ([]byte, error) { if !b.hasInnerState() { return nil, errors.New("nil beacon state") } - //TODO: Blocked by https://github.com/ferranbt/fastssz/pull/65 - return []byte{}, nil + return b.state.MarshalSSZ() } -// ProtobufBeaconState transforms an input into beacon state hard fork 1 in the form of protobuf. +// ProtobufBeaconState transforms an input into beacon state Merge in the form of protobuf. // Error is returned if the input is not type protobuf beacon state. -func ProtobufBeaconState(s interface{}) (*ethpb.BeaconStateAltair, error) { - pbState, ok := s.(*ethpb.BeaconStateAltair) +func ProtobufBeaconState(s interface{}) (*ethpb.BeaconStateMerge, error) { + pbState, ok := s.(*ethpb.BeaconStateMerge) if !ok { - return nil, errors.New("input is not type pb.BeaconStateAltair") + return nil, errors.New("input is not type pb.BeaconStateMerge") } return pbState, nil } diff --git a/beacon-chain/state/v3/getters_test.go b/beacon-chain/state/v3/getters_test.go index 6cf112dbcb..c96bb51b27 100644 --- a/beacon-chain/state/v3/getters_test.go +++ b/beacon-chain/state/v3/getters_test.go @@ -5,7 +5,9 @@ import ( "sync" "testing" + types "github.com/prysmaticlabs/eth2-types" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/testing/assert" "github.com/prysmaticlabs/prysm/testing/require" ) @@ -86,3 +88,105 @@ func TestNilState_NoPanic(t *testing.T) { _, err = st.NextSyncCommittee() _ = err } + +func TestBeaconState_ValidatorByPubkey(t *testing.T) { + keyCreator := func(input []byte) [48]byte { + nKey := [48]byte{} + copy(nKey[:1], input) + return nKey + } + + tests := []struct { + name string + modifyFunc func(b *BeaconState, k [48]byte) + exists bool + expectedIdx types.ValidatorIndex + largestIdxInSet types.ValidatorIndex + }{ + { + name: "retrieve validator", + modifyFunc: func(b *BeaconState, key [48]byte) { + assert.NoError(t, b.AppendValidator(ðpb.Validator{PublicKey: key[:]})) + }, + exists: true, + expectedIdx: 0, + }, + { + name: "retrieve validator with multiple validators from the start", + modifyFunc: func(b *BeaconState, key [48]byte) { + key1 := keyCreator([]byte{'C'}) + key2 := keyCreator([]byte{'D'}) + assert.NoError(t, b.AppendValidator(ðpb.Validator{PublicKey: key[:]})) + assert.NoError(t, b.AppendValidator(ðpb.Validator{PublicKey: key1[:]})) + assert.NoError(t, b.AppendValidator(ðpb.Validator{PublicKey: key2[:]})) + }, + exists: true, + expectedIdx: 0, + }, + { + name: "retrieve validator with multiple validators", + modifyFunc: func(b *BeaconState, key [48]byte) { + key1 := keyCreator([]byte{'C'}) + key2 := keyCreator([]byte{'D'}) + assert.NoError(t, b.AppendValidator(ðpb.Validator{PublicKey: key1[:]})) + assert.NoError(t, b.AppendValidator(ðpb.Validator{PublicKey: key2[:]})) + assert.NoError(t, b.AppendValidator(ðpb.Validator{PublicKey: key[:]})) + }, + exists: true, + expectedIdx: 2, + }, + { + name: "retrieve validator with multiple validators from the start with shared state", + modifyFunc: func(b *BeaconState, key [48]byte) { + key1 := keyCreator([]byte{'C'}) + key2 := keyCreator([]byte{'D'}) + assert.NoError(t, b.AppendValidator(ðpb.Validator{PublicKey: key[:]})) + _ = b.Copy() + assert.NoError(t, b.AppendValidator(ðpb.Validator{PublicKey: key1[:]})) + assert.NoError(t, b.AppendValidator(ðpb.Validator{PublicKey: key2[:]})) + }, + exists: true, + expectedIdx: 0, + }, + { + name: "retrieve validator with multiple validators with shared state", + modifyFunc: func(b *BeaconState, key [48]byte) { + key1 := keyCreator([]byte{'C'}) + key2 := keyCreator([]byte{'D'}) + assert.NoError(t, b.AppendValidator(ðpb.Validator{PublicKey: key1[:]})) + assert.NoError(t, b.AppendValidator(ðpb.Validator{PublicKey: key2[:]})) + n := b.Copy() + // Append to another state + assert.NoError(t, n.AppendValidator(ðpb.Validator{PublicKey: key[:]})) + + }, + exists: false, + expectedIdx: 0, + }, + { + name: "retrieve validator with multiple validators with shared state at boundary", + modifyFunc: func(b *BeaconState, key [48]byte) { + key1 := keyCreator([]byte{'C'}) + assert.NoError(t, b.AppendValidator(ðpb.Validator{PublicKey: key1[:]})) + n := b.Copy() + // Append to another state + assert.NoError(t, n.AppendValidator(ðpb.Validator{PublicKey: key[:]})) + + }, + exists: false, + expectedIdx: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s, err := InitializeFromProto(ðpb.BeaconStateMerge{}) + require.NoError(t, err) + nKey := keyCreator([]byte{'A'}) + tt.modifyFunc(s, nKey) + idx, ok := s.ValidatorIndexByPubkey(nKey) + assert.Equal(t, tt.exists, ok) + assert.Equal(t, tt.expectedIdx, idx) + }) + } +} diff --git a/beacon-chain/state/v3/setters_test.go b/beacon-chain/state/v3/setters_test.go new file mode 100644 index 0000000000..c5d9f6d40f --- /dev/null +++ b/beacon-chain/state/v3/setters_test.go @@ -0,0 +1,185 @@ +package v3 + +import ( + "context" + "strconv" + "testing" + + types "github.com/prysmaticlabs/eth2-types" + "github.com/prysmaticlabs/go-bitfield" + "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" + stateTypes "github.com/prysmaticlabs/prysm/beacon-chain/state/types" + "github.com/prysmaticlabs/prysm/config/params" + "github.com/prysmaticlabs/prysm/encoding/bytesutil" + eth "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/testing/assert" + "github.com/prysmaticlabs/prysm/testing/require" +) + +func TestAppendBeyondIndicesLimit(t *testing.T) { + zeroHash := params.BeaconConfig().ZeroHash + mockblockRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot) + for i := 0; i < len(mockblockRoots); i++ { + mockblockRoots[i] = zeroHash[:] + } + + mockstateRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot) + for i := 0; i < len(mockstateRoots); i++ { + mockstateRoots[i] = zeroHash[:] + } + mockrandaoMixes := make([][]byte, params.BeaconConfig().EpochsPerHistoricalVector) + for i := 0; i < len(mockrandaoMixes); i++ { + mockrandaoMixes[i] = zeroHash[:] + } + payload := ðpb.ExecutionPayloadHeader{ + ParentHash: make([]byte, 32), + FeeRecipient: make([]byte, 20), + StateRoot: make([]byte, 32), + ReceiptRoot: make([]byte, 32), + LogsBloom: make([]byte, 256), + Random: make([]byte, 32), + BaseFeePerGas: make([]byte, 32), + BlockHash: make([]byte, 32), + TransactionsRoot: make([]byte, 32), + } + st, err := InitializeFromProto(ðpb.BeaconStateMerge{ + Slot: 1, + CurrentEpochParticipation: []byte{}, + PreviousEpochParticipation: []byte{}, + Validators: []*eth.Validator{}, + Eth1Data: ð.Eth1Data{}, + BlockRoots: mockblockRoots, + StateRoots: mockstateRoots, + RandaoMixes: mockrandaoMixes, + LatestExecutionPayloadHeader: payload, + }) + require.NoError(t, err) + _, err = st.HashTreeRoot(context.Background()) + require.NoError(t, err) + for i := stateTypes.FieldIndex(0); i < stateTypes.FieldIndex(params.BeaconConfig().BeaconStateMergeFieldCount); i++ { + st.dirtyFields[i] = true + } + _, err = st.HashTreeRoot(context.Background()) + require.NoError(t, err) + for i := 0; i < 10; i++ { + assert.NoError(t, st.AppendValidator(ð.Validator{})) + } + assert.Equal(t, false, st.rebuildTrie[validators]) + assert.NotEqual(t, len(st.dirtyIndices[validators]), 0) + + for i := 0; i < indicesLimit; i++ { + assert.NoError(t, st.AppendValidator(ð.Validator{})) + } + assert.Equal(t, true, st.rebuildTrie[validators]) + assert.Equal(t, len(st.dirtyIndices[validators]), 0) +} + +func TestBeaconState_AppendBalanceWithTrie(t *testing.T) { + count := uint64(100) + vals := make([]*ethpb.Validator, 0, count) + bals := make([]uint64, 0, count) + for i := uint64(1); i < count; i++ { + someRoot := [32]byte{} + someKey := [48]byte{} + copy(someRoot[:], strconv.Itoa(int(i))) + copy(someKey[:], strconv.Itoa(int(i))) + vals = append(vals, ðpb.Validator{ + PublicKey: someKey[:], + WithdrawalCredentials: someRoot[:], + EffectiveBalance: params.BeaconConfig().MaxEffectiveBalance, + Slashed: false, + ActivationEligibilityEpoch: 1, + ActivationEpoch: 1, + ExitEpoch: 1, + WithdrawableEpoch: 1, + }) + bals = append(bals, params.BeaconConfig().MaxEffectiveBalance) + } + zeroHash := params.BeaconConfig().ZeroHash + mockblockRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot) + for i := 0; i < len(mockblockRoots); i++ { + mockblockRoots[i] = zeroHash[:] + } + + mockstateRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot) + for i := 0; i < len(mockstateRoots); i++ { + mockstateRoots[i] = zeroHash[:] + } + mockrandaoMixes := make([][]byte, params.BeaconConfig().EpochsPerHistoricalVector) + for i := 0; i < len(mockrandaoMixes); i++ { + mockrandaoMixes[i] = zeroHash[:] + } + var pubKeys [][]byte + for i := uint64(0); i < params.BeaconConfig().SyncCommitteeSize; i++ { + pubKeys = append(pubKeys, bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength)) + } + payload := ðpb.ExecutionPayloadHeader{ + ParentHash: make([]byte, 32), + FeeRecipient: make([]byte, 20), + StateRoot: make([]byte, 32), + ReceiptRoot: make([]byte, 32), + LogsBloom: make([]byte, 256), + Random: make([]byte, 32), + BaseFeePerGas: make([]byte, 32), + BlockHash: make([]byte, 32), + TransactionsRoot: make([]byte, 32), + } + st, err := InitializeFromProto(ðpb.BeaconStateMerge{ + Slot: 1, + GenesisValidatorsRoot: make([]byte, 32), + Fork: ðpb.Fork{ + PreviousVersion: make([]byte, 4), + CurrentVersion: make([]byte, 4), + Epoch: 0, + }, + LatestBlockHeader: ðpb.BeaconBlockHeader{ + ParentRoot: make([]byte, 32), + StateRoot: make([]byte, 32), + BodyRoot: make([]byte, 32), + }, + CurrentEpochParticipation: []byte{}, + PreviousEpochParticipation: []byte{}, + Validators: vals, + Balances: bals, + Eth1Data: ð.Eth1Data{ + DepositRoot: make([]byte, 32), + BlockHash: make([]byte, 32), + }, + BlockRoots: mockblockRoots, + StateRoots: mockstateRoots, + RandaoMixes: mockrandaoMixes, + JustificationBits: bitfield.NewBitvector4(), + PreviousJustifiedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)}, + CurrentJustifiedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)}, + FinalizedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)}, + Slashings: make([]uint64, params.BeaconConfig().EpochsPerSlashingsVector), + CurrentSyncCommittee: ðpb.SyncCommittee{ + Pubkeys: pubKeys, + AggregatePubkey: make([]byte, 48), + }, + NextSyncCommittee: ðpb.SyncCommittee{ + Pubkeys: pubKeys, + AggregatePubkey: make([]byte, 48), + }, + LatestExecutionPayloadHeader: payload, + }) + assert.NoError(t, err) + _, err = st.HashTreeRoot(context.Background()) + assert.NoError(t, err) + + for i := 0; i < 100; i++ { + if i%2 == 0 { + assert.NoError(t, st.UpdateBalancesAtIndex(types.ValidatorIndex(i), 1000)) + } + if i%3 == 0 { + assert.NoError(t, st.AppendBalance(1000)) + } + } + _, err = st.HashTreeRoot(context.Background()) + assert.NoError(t, err) + newRt := bytesutil.ToBytes32(st.merkleLayers[0][balances]) + wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.state.Balances) + assert.NoError(t, err) + assert.Equal(t, wantedRt, newRt, "state roots are unequal") +} diff --git a/beacon-chain/state/v3/state_trie.go b/beacon-chain/state/v3/state_trie.go index 227c7456c5..813d27c0c2 100644 --- a/beacon-chain/state/v3/state_trie.go +++ b/beacon-chain/state/v3/state_trie.go @@ -1,14 +1,24 @@ package v3 import ( + "context" + "runtime" + "sort" + "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prysmaticlabs/prysm/beacon-chain/state" "github.com/prysmaticlabs/prysm/beacon-chain/state/fieldtrie" "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" "github.com/prysmaticlabs/prysm/beacon-chain/state/types" "github.com/prysmaticlabs/prysm/config/params" + "github.com/prysmaticlabs/prysm/container/slice" + "github.com/prysmaticlabs/prysm/crypto/hash" + "github.com/prysmaticlabs/prysm/encoding/bytesutil" + "github.com/prysmaticlabs/prysm/encoding/ssz" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "go.opencensus.io/trace" "google.golang.org/protobuf/proto" ) @@ -31,7 +41,7 @@ func InitializeFromProtoUnsafe(st *ethpb.BeaconStateMerge) (*BeaconState, error) return nil, errors.New("received nil state") } - fieldCount := params.BeaconConfig().BeaconStateAltairFieldCount + fieldCount := params.BeaconConfig().BeaconStateMergeFieldCount b := &BeaconState{ state: st, dirtyFields: make(map[types.FieldIndex]bool, fieldCount), @@ -65,7 +75,319 @@ func InitializeFromProtoUnsafe(st *ethpb.BeaconStateMerge) (*BeaconState, error) b.sharedFieldReferences[balances] = stateutil.NewRef(1) b.sharedFieldReferences[inactivityScores] = stateutil.NewRef(1) // New in Altair. b.sharedFieldReferences[historicalRoots] = stateutil.NewRef(1) - + b.sharedFieldReferences[latestExecutionPayloadHeader] = stateutil.NewRef(1) // New in Merge. stateCount.Inc() return b, nil } + +// Copy returns a deep copy of the beacon state. +func (b *BeaconState) Copy() state.BeaconState { + if !b.hasInnerState() { + return nil + } + + b.lock.RLock() + defer b.lock.RUnlock() + fieldCount := params.BeaconConfig().BeaconStateMergeFieldCount + + dst := &BeaconState{ + state: ðpb.BeaconStateMerge{ + // Primitive types, safe to copy. + GenesisTime: b.state.GenesisTime, + Slot: b.state.Slot, + Eth1DepositIndex: b.state.Eth1DepositIndex, + + // Large arrays, infrequently changed, constant size. + RandaoMixes: b.state.RandaoMixes, + StateRoots: b.state.StateRoots, + BlockRoots: b.state.BlockRoots, + Slashings: b.state.Slashings, + Eth1DataVotes: b.state.Eth1DataVotes, + + // Large arrays, increases over time. + Validators: b.state.Validators, + Balances: b.state.Balances, + HistoricalRoots: b.state.HistoricalRoots, + PreviousEpochParticipation: b.state.PreviousEpochParticipation, + CurrentEpochParticipation: b.state.CurrentEpochParticipation, + InactivityScores: b.state.InactivityScores, + + // Everything else, too small to be concerned about, constant size. + Fork: b.fork(), + LatestBlockHeader: b.latestBlockHeader(), + Eth1Data: b.eth1Data(), + JustificationBits: b.justificationBits(), + PreviousJustifiedCheckpoint: b.previousJustifiedCheckpoint(), + CurrentJustifiedCheckpoint: b.currentJustifiedCheckpoint(), + FinalizedCheckpoint: b.finalizedCheckpoint(), + GenesisValidatorsRoot: b.genesisValidatorRoot(), + CurrentSyncCommittee: b.currentSyncCommittee(), + NextSyncCommittee: b.nextSyncCommittee(), + LatestExecutionPayloadHeader: b.latestExecutionPayloadHeader(), + }, + dirtyFields: make(map[types.FieldIndex]bool, fieldCount), + dirtyIndices: make(map[types.FieldIndex][]uint64, fieldCount), + rebuildTrie: make(map[types.FieldIndex]bool, fieldCount), + sharedFieldReferences: make(map[types.FieldIndex]*stateutil.Reference, 11), + stateFieldLeaves: make(map[types.FieldIndex]*fieldtrie.FieldTrie, fieldCount), + + // Copy on write validator index map. + valMapHandler: b.valMapHandler, + } + + for field, ref := range b.sharedFieldReferences { + ref.AddRef() + dst.sharedFieldReferences[field] = ref + } + + // Increment ref for validator map + b.valMapHandler.AddRef() + + for i := range b.dirtyFields { + dst.dirtyFields[i] = true + } + + for i := range b.dirtyIndices { + indices := make([]uint64, len(b.dirtyIndices[i])) + copy(indices, b.dirtyIndices[i]) + dst.dirtyIndices[i] = indices + } + + for i := range b.rebuildTrie { + dst.rebuildTrie[i] = true + } + + for fldIdx, fieldTrie := range b.stateFieldLeaves { + dst.stateFieldLeaves[fldIdx] = fieldTrie + if fieldTrie.FieldReference() != nil { + fieldTrie.Lock() + fieldTrie.FieldReference().AddRef() + fieldTrie.Unlock() + } + } + + if b.merkleLayers != nil { + dst.merkleLayers = make([][][]byte, len(b.merkleLayers)) + for i, layer := range b.merkleLayers { + dst.merkleLayers[i] = make([][]byte, len(layer)) + for j, content := range layer { + dst.merkleLayers[i][j] = make([]byte, len(content)) + copy(dst.merkleLayers[i][j], content) + } + } + } + stateCount.Inc() + // Finalizer runs when dst is being destroyed in garbage collection. + runtime.SetFinalizer(dst, func(b *BeaconState) { + for field, v := range b.sharedFieldReferences { + v.MinusRef() + if b.stateFieldLeaves[field].FieldReference() != nil { + b.stateFieldLeaves[field].FieldReference().MinusRef() + } + } + for i := 0; i < fieldCount; i++ { + field := types.FieldIndex(i) + delete(b.stateFieldLeaves, field) + delete(b.dirtyIndices, field) + delete(b.dirtyFields, field) + delete(b.sharedFieldReferences, field) + delete(b.stateFieldLeaves, field) + } + stateCount.Sub(1) + }) + + return dst +} + +// HashTreeRoot of the beacon state retrieves the Merkle root of the trie +// representation of the beacon state based on the eth2 Simple Serialize specification. +func (b *BeaconState) HashTreeRoot(ctx context.Context) ([32]byte, error) { + _, span := trace.StartSpan(ctx, "BeaconStateMerge.HashTreeRoot") + defer span.End() + + b.lock.Lock() + defer b.lock.Unlock() + + if b.merkleLayers == nil || len(b.merkleLayers) == 0 { + fieldRoots, err := computeFieldRoots(ctx, b.state) + if err != nil { + return [32]byte{}, err + } + layers := stateutil.Merkleize(fieldRoots) + b.merkleLayers = layers + b.dirtyFields = make(map[types.FieldIndex]bool, params.BeaconConfig().BeaconStateMergeFieldCount) + } + + for field := range b.dirtyFields { + root, err := b.rootSelector(ctx, field) + if err != nil { + return [32]byte{}, err + } + b.merkleLayers[0][field] = root[:] + b.recomputeRoot(int(field)) + delete(b.dirtyFields, field) + } + return bytesutil.ToBytes32(b.merkleLayers[len(b.merkleLayers)-1][0]), nil +} + +// FieldReferencesCount returns the reference count held by each field. This +// also includes the field trie held by each field. +func (b *BeaconState) FieldReferencesCount() map[string]uint64 { + refMap := make(map[string]uint64) + b.lock.RLock() + defer b.lock.RUnlock() + for i, f := range b.sharedFieldReferences { + refMap[i.String(b.Version())] = uint64(f.Refs()) + } + for i, f := range b.stateFieldLeaves { + numOfRefs := uint64(f.FieldReference().Refs()) + f.RLock() + if !f.Empty() { + refMap[i.String(b.Version())+"_trie"] = numOfRefs + } + f.RUnlock() + } + return refMap +} + +// IsNil checks if the state and the underlying proto +// object are nil. +func (b *BeaconState) IsNil() bool { + return b == nil || b.state == nil +} + +func (b *BeaconState) rootSelector(ctx context.Context, field types.FieldIndex) ([32]byte, error) { + hasher := hash.CustomSHA256Hasher() + switch field { + case genesisTime: + return ssz.Uint64Root(b.state.GenesisTime), nil + case genesisValidatorRoot: + return bytesutil.ToBytes32(b.state.GenesisValidatorsRoot), nil + case slot: + return ssz.Uint64Root(uint64(b.state.Slot)), nil + case eth1DepositIndex: + return ssz.Uint64Root(b.state.Eth1DepositIndex), nil + case fork: + return ssz.ForkRoot(b.state.Fork) + case latestBlockHeader: + return stateutil.BlockHeaderRoot(b.state.LatestBlockHeader) + case blockRoots: + if b.rebuildTrie[field] { + err := b.resetFieldTrie(field, b.state.BlockRoots, uint64(params.BeaconConfig().SlotsPerHistoricalRoot)) + if err != nil { + return [32]byte{}, err + } + b.dirtyIndices[field] = []uint64{} + delete(b.rebuildTrie, field) + return b.stateFieldLeaves[field].TrieRoot() + } + return b.recomputeFieldTrie(blockRoots, b.state.BlockRoots) + case stateRoots: + if b.rebuildTrie[field] { + err := b.resetFieldTrie(field, b.state.StateRoots, uint64(params.BeaconConfig().SlotsPerHistoricalRoot)) + if err != nil { + return [32]byte{}, err + } + b.dirtyIndices[field] = []uint64{} + delete(b.rebuildTrie, field) + return b.stateFieldLeaves[field].TrieRoot() + } + return b.recomputeFieldTrie(stateRoots, b.state.StateRoots) + case historicalRoots: + return ssz.ByteArrayRootWithLimit(b.state.HistoricalRoots, params.BeaconConfig().HistoricalRootsLimit) + case eth1Data: + return eth1Root(hasher, b.state.Eth1Data) + case eth1DataVotes: + if b.rebuildTrie[field] { + err := b.resetFieldTrie(field, b.state.Eth1DataVotes, uint64(params.BeaconConfig().SlotsPerEpoch.Mul(uint64(params.BeaconConfig().EpochsPerEth1VotingPeriod)))) + if err != nil { + return [32]byte{}, err + } + b.dirtyIndices[field] = []uint64{} + delete(b.rebuildTrie, field) + return b.stateFieldLeaves[field].TrieRoot() + } + return b.recomputeFieldTrie(field, b.state.Eth1DataVotes) + case validators: + if b.rebuildTrie[field] { + err := b.resetFieldTrie(field, b.state.Validators, params.BeaconConfig().ValidatorRegistryLimit) + if err != nil { + return [32]byte{}, err + } + b.dirtyIndices[validators] = []uint64{} + delete(b.rebuildTrie, validators) + return b.stateFieldLeaves[field].TrieRoot() + } + return b.recomputeFieldTrie(validators, b.state.Validators) + case balances: + return stateutil.Uint64ListRootWithRegistryLimit(b.state.Balances) + case randaoMixes: + if b.rebuildTrie[field] { + err := b.resetFieldTrie(field, b.state.RandaoMixes, uint64(params.BeaconConfig().EpochsPerHistoricalVector)) + if err != nil { + return [32]byte{}, err + } + b.dirtyIndices[field] = []uint64{} + delete(b.rebuildTrie, field) + return b.stateFieldLeaves[field].TrieRoot() + } + return b.recomputeFieldTrie(randaoMixes, b.state.RandaoMixes) + case slashings: + return ssz.SlashingsRoot(b.state.Slashings) + case previousEpochParticipationBits: + return stateutil.ParticipationBitsRoot(b.state.PreviousEpochParticipation) + case currentEpochParticipationBits: + return stateutil.ParticipationBitsRoot(b.state.CurrentEpochParticipation) + case justificationBits: + return bytesutil.ToBytes32(b.state.JustificationBits), nil + case previousJustifiedCheckpoint: + return ssz.CheckpointRoot(hasher, b.state.PreviousJustifiedCheckpoint) + case currentJustifiedCheckpoint: + return ssz.CheckpointRoot(hasher, b.state.CurrentJustifiedCheckpoint) + case finalizedCheckpoint: + return ssz.CheckpointRoot(hasher, b.state.FinalizedCheckpoint) + case inactivityScores: + return stateutil.Uint64ListRootWithRegistryLimit(b.state.InactivityScores) + case currentSyncCommittee: + return stateutil.SyncCommitteeRoot(b.state.CurrentSyncCommittee) + case nextSyncCommittee: + return stateutil.SyncCommitteeRoot(b.state.NextSyncCommittee) + case latestExecutionPayloadHeader: + return b.state.LatestExecutionPayloadHeader.HashTreeRoot() + } + return [32]byte{}, errors.New("invalid field index provided") +} + +func (b *BeaconState) recomputeFieldTrie(index types.FieldIndex, elements interface{}) ([32]byte, error) { + fTrie := b.stateFieldLeaves[index] + if fTrie.FieldReference().Refs() > 1 { + fTrie.Lock() + defer fTrie.Unlock() + fTrie.FieldReference().MinusRef() + newTrie := fTrie.CopyTrie() + b.stateFieldLeaves[index] = newTrie + fTrie = newTrie + } + // remove duplicate indexes + b.dirtyIndices[index] = slice.SetUint64(b.dirtyIndices[index]) + // sort indexes again + sort.Slice(b.dirtyIndices[index], func(i int, j int) bool { + return b.dirtyIndices[index][i] < b.dirtyIndices[index][j] + }) + root, err := fTrie.RecomputeTrie(b.dirtyIndices[index], elements) + if err != nil { + return [32]byte{}, err + } + b.dirtyIndices[index] = []uint64{} + return root, nil +} + +func (b *BeaconState) resetFieldTrie(index types.FieldIndex, elements interface{}, length uint64) error { + fTrie, err := fieldtrie.NewFieldTrie(index, fieldMap[index], elements, length) + if err != nil { + return err + } + b.stateFieldLeaves[index] = fTrie + b.dirtyIndices[index] = []uint64{} + return nil +} diff --git a/beacon-chain/state/v3/state_trie_test.go b/beacon-chain/state/v3/state_trie_test.go new file mode 100644 index 0000000000..5b72723858 --- /dev/null +++ b/beacon-chain/state/v3/state_trie_test.go @@ -0,0 +1,167 @@ +package v3 + +import ( + "strconv" + "sync" + "testing" + + "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" + "github.com/prysmaticlabs/prysm/config/features" + "github.com/prysmaticlabs/prysm/config/params" + "github.com/prysmaticlabs/prysm/encoding/bytesutil" + ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/testing/assert" + "github.com/prysmaticlabs/prysm/testing/require" +) + +func TestMain(m *testing.M) { + resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true}) + defer resetCfg() + m.Run() +} + +func TestValidatorMap_DistinctCopy(t *testing.T) { + count := uint64(100) + vals := make([]*ethpb.Validator, 0, count) + for i := uint64(1); i < count; i++ { + someRoot := [32]byte{} + someKey := [48]byte{} + copy(someRoot[:], strconv.Itoa(int(i))) + copy(someKey[:], strconv.Itoa(int(i))) + vals = append(vals, ðpb.Validator{ + PublicKey: someKey[:], + WithdrawalCredentials: someRoot[:], + EffectiveBalance: params.BeaconConfig().MaxEffectiveBalance, + Slashed: false, + ActivationEligibilityEpoch: 1, + ActivationEpoch: 1, + ExitEpoch: 1, + WithdrawableEpoch: 1, + }) + } + handler := stateutil.NewValMapHandler(vals) + newHandler := handler.Copy() + wantedPubkey := strconv.Itoa(22) + handler.Set(bytesutil.ToBytes48([]byte(wantedPubkey)), 27) + val1, _ := handler.Get(bytesutil.ToBytes48([]byte(wantedPubkey))) + val2, _ := newHandler.Get(bytesutil.ToBytes48([]byte(wantedPubkey))) + assert.NotEqual(t, val1, val2, "Values are supposed to be unequal due to copy") +} + +func TestInitializeFromProto(t *testing.T) { + type test struct { + name string + state *ethpb.BeaconStateMerge + error string + } + initTests := []test{ + { + name: "nil state", + state: nil, + error: "received nil state", + }, + { + name: "nil validators", + state: ðpb.BeaconStateMerge{ + Slot: 4, + Validators: nil, + }, + }, + { + name: "empty state", + state: ðpb.BeaconStateMerge{}, + }, + } + for _, tt := range initTests { + t.Run(tt.name, func(t *testing.T) { + _, err := InitializeFromProto(tt.state) + if tt.error != "" { + require.ErrorContains(t, tt.error, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestBeaconState_NoDeadlock(t *testing.T) { + count := uint64(100) + vals := make([]*ethpb.Validator, 0, count) + for i := uint64(1); i < count; i++ { + someRoot := [32]byte{} + someKey := [48]byte{} + copy(someRoot[:], strconv.Itoa(int(i))) + copy(someKey[:], strconv.Itoa(int(i))) + vals = append(vals, ðpb.Validator{ + PublicKey: someKey[:], + WithdrawalCredentials: someRoot[:], + EffectiveBalance: params.BeaconConfig().MaxEffectiveBalance, + Slashed: false, + ActivationEligibilityEpoch: 1, + ActivationEpoch: 1, + ExitEpoch: 1, + WithdrawableEpoch: 1, + }) + } + st, err := InitializeFromProtoUnsafe(ðpb.BeaconStateMerge{ + Validators: vals, + }) + assert.NoError(t, err) + + wg := new(sync.WaitGroup) + + wg.Add(1) + go func() { + // Continuously lock and unlock the state + // by acquiring the lock. + for i := 0; i < 1000; i++ { + for _, f := range st.stateFieldLeaves { + f.Lock() + if f.Empty() { + f.InsertFieldLayer(make([][]*[32]byte, 10)) + } + f.Unlock() + f.FieldReference().AddRef() + } + } + wg.Done() + }() + // Constantly read from the offending portion + // of the code to ensure there is no possible + // recursive read locking. + for i := 0; i < 1000; i++ { + go func() { + _ = st.FieldReferencesCount() + }() + } + // Test will not terminate in the event of a deadlock. + wg.Wait() +} + +func TestInitializeFromProtoUnsafe(t *testing.T) { + type test struct { + name string + state *ethpb.BeaconStateMerge + error string + } + initTests := []test{ + { + name: "nil state", + state: nil, + error: "received nil state", + }, + { + name: "nil validators", + state: ðpb.BeaconStateMerge{ + Slot: 4, + Validators: nil, + }, + }, + { + name: "empty state", + state: ðpb.BeaconStateMerge{}, + }, + // TODO: Add full state. Blocked by testutil migration. + } + _ = initTests +}