From 0ee203fc2e507ae1759692eba0d202ff591c91ce Mon Sep 17 00:00:00 2001 From: Raul Jordan Date: Mon, 23 Aug 2021 11:53:50 -0500 Subject: [PATCH] Deduplicate Copy Functions In State Packages (#9437) * remove duplicated copy functions * add all unit tests Co-authored-by: Nishant Das --- beacon-chain/state/v1/getters_attestation.go | 5 +- beacon-chain/state/v1/getters_block.go | 9 ++- beacon-chain/state/v1/getters_checkpoint.go | 7 +- beacon-chain/state/v1/getters_misc.go | 3 +- beacon-chain/state/v1/getters_randao.go | 8 ++- beacon-chain/state/v1/getters_state.go | 56 +-------------- beacon-chain/state/v2/getters.go | 57 +++------------ shared/bytesutil/bytes.go | 19 ++++- shared/bytesutil/bytes_test.go | 75 ++++++++++++++++++++ shared/copyutil/cloners.go | 15 +++- shared/copyutil/cloners_test.go | 30 ++++++++ shared/testutil/altair.go | 2 +- shared/trieutil/sparse_merkle.go | 4 +- 13 files changed, 174 insertions(+), 116 deletions(-) diff --git a/beacon-chain/state/v1/getters_attestation.go b/beacon-chain/state/v1/getters_attestation.go index ee04b4b044..bceff6e15c 100644 --- a/beacon-chain/state/v1/getters_attestation.go +++ b/beacon-chain/state/v1/getters_attestation.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/shared/copyutil" "github.com/prysmaticlabs/prysm/shared/hashutil" "github.com/prysmaticlabs/prysm/shared/htrutils" "github.com/prysmaticlabs/prysm/shared/params" @@ -35,7 +36,7 @@ func (b *BeaconState) previousEpochAttestations() []*ethpb.PendingAttestation { return nil } - return b.safeCopyPendingAttestationSlice(b.state.PreviousEpochAttestations) + return copyutil.CopyPendingAttestationSlice(b.state.PreviousEpochAttestations) } // CurrentEpochAttestations corresponding to blocks on the beacon chain. @@ -60,7 +61,7 @@ func (b *BeaconState) currentEpochAttestations() []*ethpb.PendingAttestation { return nil } - return b.safeCopyPendingAttestationSlice(b.state.CurrentEpochAttestations) + return copyutil.CopyPendingAttestationSlice(b.state.CurrentEpochAttestations) } func (h *stateRootHasher) epochAttestationsRoot(atts []*ethpb.PendingAttestation) ([32]byte, error) { diff --git a/beacon-chain/state/v1/getters_block.go b/beacon-chain/state/v1/getters_block.go index 10c9e45927..8083034c1f 100644 --- a/beacon-chain/state/v1/getters_block.go +++ b/beacon-chain/state/v1/getters_block.go @@ -1,6 +1,9 @@ package v1 -import ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" +import ( + ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/shared/bytesutil" +) // LatestBlockHeader stored within the beacon state. func (b *BeaconState) LatestBlockHeader() *ethpb.BeaconBlockHeader { @@ -66,7 +69,7 @@ func (b *BeaconState) blockRoots() [][]byte { if !b.hasInnerState() { return nil } - return b.safeCopy2DByteSlice(b.state.BlockRoots) + return bytesutil.SafeCopy2dBytes(b.state.BlockRoots) } // BlockRootAtIndex retrieves a specific block root based on an @@ -92,5 +95,5 @@ func (b *BeaconState) blockRootAtIndex(idx uint64) ([]byte, error) { if !b.hasInnerState() { return nil, ErrNilInnerState } - return b.safeCopyBytesAtIndex(b.state.BlockRoots, idx) + return bytesutil.SafeCopyRootAtIndex(b.state.BlockRoots, idx) } diff --git a/beacon-chain/state/v1/getters_checkpoint.go b/beacon-chain/state/v1/getters_checkpoint.go index 410dbc1c5e..2919ec043a 100644 --- a/beacon-chain/state/v1/getters_checkpoint.go +++ b/beacon-chain/state/v1/getters_checkpoint.go @@ -6,6 +6,7 @@ import ( types "github.com/prysmaticlabs/eth2-types" "github.com/prysmaticlabs/go-bitfield" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/shared/copyutil" ) // JustificationBits marking which epochs have been justified in the beacon chain. @@ -60,7 +61,7 @@ func (b *BeaconState) previousJustifiedCheckpoint() *ethpb.Checkpoint { return nil } - return b.safeCopyCheckpoint(b.state.PreviousJustifiedCheckpoint) + return copyutil.CopyCheckpoint(b.state.PreviousJustifiedCheckpoint) } // CurrentJustifiedCheckpoint denoting an epoch and block root. @@ -85,7 +86,7 @@ func (b *BeaconState) currentJustifiedCheckpoint() *ethpb.Checkpoint { return nil } - return b.safeCopyCheckpoint(b.state.CurrentJustifiedCheckpoint) + return copyutil.CopyCheckpoint(b.state.CurrentJustifiedCheckpoint) } // MatchCurrentJustifiedCheckpoint returns true if input justified checkpoint matches @@ -142,7 +143,7 @@ func (b *BeaconState) finalizedCheckpoint() *ethpb.Checkpoint { return nil } - return b.safeCopyCheckpoint(b.state.FinalizedCheckpoint) + return copyutil.CopyCheckpoint(b.state.FinalizedCheckpoint) } // FinalizedCheckpointEpoch returns the epoch value of the finalized checkpoint. diff --git a/beacon-chain/state/v1/getters_misc.go b/beacon-chain/state/v1/getters_misc.go index 945c2b9d60..658d723445 100644 --- a/beacon-chain/state/v1/getters_misc.go +++ b/beacon-chain/state/v1/getters_misc.go @@ -5,6 +5,7 @@ import ( types "github.com/prysmaticlabs/eth2-types" "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/shared/bytesutil" "github.com/prysmaticlabs/prysm/shared/featureconfig" "github.com/prysmaticlabs/prysm/shared/hashutil" "github.com/prysmaticlabs/prysm/shared/htrutils" @@ -150,7 +151,7 @@ func (b *BeaconState) historicalRoots() [][]byte { if !b.hasInnerState() { return nil } - return b.safeCopy2DByteSlice(b.state.HistoricalRoots) + return bytesutil.SafeCopy2dBytes(b.state.HistoricalRoots) } // balancesLength returns the length of the balances slice. diff --git a/beacon-chain/state/v1/getters_randao.go b/beacon-chain/state/v1/getters_randao.go index d34039943f..307b0d1955 100644 --- a/beacon-chain/state/v1/getters_randao.go +++ b/beacon-chain/state/v1/getters_randao.go @@ -1,5 +1,9 @@ package v1 +import ( + "github.com/prysmaticlabs/prysm/shared/bytesutil" +) + // RandaoMixes of block proposers on the beacon chain. func (b *BeaconState) RandaoMixes() [][]byte { if !b.hasInnerState() { @@ -22,7 +26,7 @@ func (b *BeaconState) randaoMixes() [][]byte { return nil } - return b.safeCopy2DByteSlice(b.state.RandaoMixes) + return bytesutil.SafeCopy2dBytes(b.state.RandaoMixes) } // RandaoMixAtIndex retrieves a specific block root based on an @@ -49,7 +53,7 @@ func (b *BeaconState) randaoMixAtIndex(idx uint64) ([]byte, error) { return nil, ErrNilInnerState } - return b.safeCopyBytesAtIndex(b.state.RandaoMixes, idx) + return bytesutil.SafeCopyRootAtIndex(b.state.RandaoMixes, idx) } // RandaoMixesLength returns the length of the randao mixes slice. diff --git a/beacon-chain/state/v1/getters_state.go b/beacon-chain/state/v1/getters_state.go index 82418b3928..b0d68b4168 100644 --- a/beacon-chain/state/v1/getters_state.go +++ b/beacon-chain/state/v1/getters_state.go @@ -1,12 +1,9 @@ package v1 import ( - "fmt" - - "github.com/prysmaticlabs/prysm/shared/copyutil" - "github.com/pkg/errors" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/shared/bytesutil" ) // InnerStateUnsafe returns the pointer value of the underlying @@ -78,7 +75,7 @@ func (b *BeaconState) stateRoots() [][]byte { if !b.hasInnerState() { return nil } - return b.safeCopy2DByteSlice(b.state.StateRoots) + return bytesutil.SafeCopy2dBytes(b.state.StateRoots) } // StateRootAtIndex retrieves a specific state root based on an @@ -104,7 +101,7 @@ func (b *BeaconState) stateRootAtIndex(idx uint64) ([]byte, error) { if !b.hasInnerState() { return nil, ErrNilInnerState } - return b.safeCopyBytesAtIndex(b.state.StateRoots, idx) + return bytesutil.SafeCopyRootAtIndex(b.state.StateRoots, idx) } // MarshalSSZ marshals the underlying beacon state to bytes. @@ -124,50 +121,3 @@ func ProtobufBeaconState(s interface{}) (*ethpb.BeaconState, error) { } return pbState, nil } - -func (b *BeaconState) safeCopy2DByteSlice(input [][]byte) [][]byte { - if input == nil { - return nil - } - - dst := make([][]byte, len(input)) - for i, r := range input { - tmp := make([]byte, len(r)) - copy(tmp, r) - dst[i] = tmp - } - return dst -} - -func (b *BeaconState) safeCopyBytesAtIndex(input [][]byte, idx uint64) ([]byte, error) { - if input == nil { - return nil, nil - } - - if uint64(len(input)) <= idx { - return nil, fmt.Errorf("index %d out of range", idx) - } - root := make([]byte, 32) - copy(root, input[idx]) - return root, nil -} - -func (b *BeaconState) safeCopyPendingAttestationSlice(input []*ethpb.PendingAttestation) []*ethpb.PendingAttestation { - if input == nil { - return nil - } - - res := make([]*ethpb.PendingAttestation, len(input)) - for i := 0; i < len(res); i++ { - res[i] = copyutil.CopyPendingAttestation(input[i]) - } - return res -} - -func (b *BeaconState) safeCopyCheckpoint(input *ethpb.Checkpoint) *ethpb.Checkpoint { - if input == nil { - return nil - } - - return copyutil.CopyCheckpoint(input) -} diff --git a/beacon-chain/state/v2/getters.go b/beacon-chain/state/v2/getters.go index 5ac9c67aa9..b5e32a9f82 100644 --- a/beacon-chain/state/v2/getters.go +++ b/beacon-chain/state/v2/getters.go @@ -288,7 +288,7 @@ func (b *BeaconState) blockRoots() [][]byte { if !b.hasInnerState() { return nil } - return b.safeCopy2DByteSlice(b.state.BlockRoots) + return bytesutil.SafeCopy2dBytes(b.state.BlockRoots) } // BlockRootAtIndex retrieves a specific block root based on an @@ -314,7 +314,7 @@ func (b *BeaconState) blockRootAtIndex(idx uint64) ([]byte, error) { if !b.hasInnerState() { return nil, ErrNilInnerState } - return b.safeCopyBytesAtIndex(b.state.BlockRoots, idx) + return bytesutil.SafeCopyRootAtIndex(b.state.BlockRoots, idx) } // StateRoots kept track of in the beacon state. @@ -338,7 +338,7 @@ func (b *BeaconState) stateRoots() [][]byte { if !b.hasInnerState() { return nil } - return b.safeCopy2DByteSlice(b.state.StateRoots) + return bytesutil.SafeCopy2dBytes(b.state.StateRoots) } // StateRootAtIndex retrieves a specific state root based on an @@ -364,7 +364,7 @@ func (b *BeaconState) stateRootAtIndex(idx uint64) ([]byte, error) { if !b.hasInnerState() { return nil, ErrNilInnerState } - return b.safeCopyBytesAtIndex(b.state.StateRoots, idx) + return bytesutil.SafeCopyRootAtIndex(b.state.StateRoots, idx) } // HistoricalRoots based on epochs stored in the beacon state. @@ -388,7 +388,7 @@ func (b *BeaconState) historicalRoots() [][]byte { if !b.hasInnerState() { return nil } - return b.safeCopy2DByteSlice(b.state.HistoricalRoots) + return bytesutil.SafeCopy2dBytes(b.state.HistoricalRoots) } // Eth1Data corresponding to the proof-of-work chain information stored in the beacon state. @@ -742,7 +742,7 @@ func (b *BeaconState) randaoMixes() [][]byte { return nil } - return b.safeCopy2DByteSlice(b.state.RandaoMixes) + return bytesutil.SafeCopy2dBytes(b.state.RandaoMixes) } // RandaoMixAtIndex retrieves a specific block root based on an @@ -769,7 +769,7 @@ func (b *BeaconState) randaoMixAtIndex(idx uint64) ([]byte, error) { return nil, ErrNilInnerState } - return b.safeCopyBytesAtIndex(b.state.RandaoMixes, idx) + return bytesutil.SafeCopyRootAtIndex(b.state.RandaoMixes, idx) } // RandaoMixesLength returns the length of the randao mixes slice. @@ -882,7 +882,7 @@ func (b *BeaconState) previousJustifiedCheckpoint() *ethpb.Checkpoint { return nil } - return b.safeCopyCheckpoint(b.state.PreviousJustifiedCheckpoint) + return copyutil.CopyCheckpoint(b.state.PreviousJustifiedCheckpoint) } // CurrentJustifiedCheckpoint denoting an epoch and block root. @@ -907,7 +907,7 @@ func (b *BeaconState) currentJustifiedCheckpoint() *ethpb.Checkpoint { return nil } - return b.safeCopyCheckpoint(b.state.CurrentJustifiedCheckpoint) + return copyutil.CopyCheckpoint(b.state.CurrentJustifiedCheckpoint) } // MatchCurrentJustifiedCheckpoint returns true if input justified checkpoint matches @@ -964,7 +964,7 @@ func (b *BeaconState) finalizedCheckpoint() *ethpb.Checkpoint { return nil } - return b.safeCopyCheckpoint(b.state.FinalizedCheckpoint) + return copyutil.CopyCheckpoint(b.state.FinalizedCheckpoint) } // FinalizedCheckpointEpoch returns the epoch value of the finalized checkpoint. @@ -1115,41 +1115,6 @@ func (b *BeaconState) InactivityScores() ([]uint64, error) { return b.inactivityScores(), nil } -func (b *BeaconState) safeCopy2DByteSlice(input [][]byte) [][]byte { - if input == nil { - return nil - } - - dst := make([][]byte, len(input)) - for i, r := range input { - tmp := make([]byte, len(r)) - copy(tmp, r) - dst[i] = tmp - } - return dst -} - -func (b *BeaconState) safeCopyBytesAtIndex(input [][]byte, idx uint64) ([]byte, error) { - if input == nil { - return nil, nil - } - - if uint64(len(input)) <= idx { - return nil, fmt.Errorf("index %d out of range", idx) - } - root := make([]byte, 32) - copy(root, input[idx]) - return root, nil -} - -func (b *BeaconState) safeCopyCheckpoint(input *ethpb.Checkpoint) *ethpb.Checkpoint { - if input == nil { - return nil - } - - return copyutil.CopyCheckpoint(input) -} - // MarshalSSZ marshals the underlying beacon state to bytes. func (b *BeaconState) MarshalSSZ() ([]byte, error) { if !b.hasInnerState() { @@ -1181,7 +1146,7 @@ func CopySyncCommittee(data *ethpb.SyncCommittee) *ethpb.SyncCommittee { return nil } return ðpb.SyncCommittee{ - Pubkeys: bytesutil.Copy2dBytes(data.Pubkeys), + Pubkeys: bytesutil.SafeCopy2dBytes(data.Pubkeys), AggregatePubkey: bytesutil.SafeCopyBytes(data.AggregatePubkey), } } diff --git a/shared/bytesutil/bytes.go b/shared/bytesutil/bytes.go index 30fa139770..0e68daa994 100644 --- a/shared/bytesutil/bytes.go +++ b/shared/bytesutil/bytes.go @@ -4,6 +4,7 @@ package bytesutil import ( "encoding/binary" "errors" + "fmt" "math/bits" "regexp" @@ -166,6 +167,20 @@ func ToLowInt64(x []byte) int64 { return int64(binary.LittleEndian.Uint64(x)) } +// SafeCopyRootAtIndex takes a copy of an 32-byte slice in a slice of byte slices. Returns error if index out of range. +func SafeCopyRootAtIndex(input [][]byte, idx uint64) ([]byte, error) { + if input == nil { + return nil, nil + } + + if uint64(len(input)) <= idx { + return nil, fmt.Errorf("index %d out of range", idx) + } + item := make([]byte, 32) + copy(item, input[idx]) + return item, nil +} + // SafeCopyBytes will copy and return a non-nil byte array, otherwise it returns nil. func SafeCopyBytes(cp []byte) []byte { if cp != nil { @@ -176,8 +191,8 @@ func SafeCopyBytes(cp []byte) []byte { return nil } -// Copy2dBytes will copy and return a non-nil 2d byte array, otherwise it returns nil. -func Copy2dBytes(ary [][]byte) [][]byte { +// SafeCopy2dBytes will copy and return a non-nil 2d byte array, otherwise it returns nil. +func SafeCopy2dBytes(ary [][]byte) [][]byte { if ary != nil { copied := make([][]byte, len(ary)) for i, a := range ary { diff --git a/shared/bytesutil/bytes_test.go b/shared/bytesutil/bytes_test.go index 62ae3bad95..195d59f1ce 100644 --- a/shared/bytesutil/bytes_test.go +++ b/shared/bytesutil/bytes_test.go @@ -1,6 +1,7 @@ package bytesutil_test import ( + "reflect" "testing" "github.com/prysmaticlabs/prysm/shared/bytesutil" @@ -397,3 +398,77 @@ func TestIsHex(t *testing.T) { assert.Equal(t, tt.b, isHex) } } + +func TestSafeCopyRootAtIndex(t *testing.T) { + tests := []struct { + name string + input [][]byte + idx uint64 + want []byte + wantErr bool + }{ + { + name: "index out of range in non-empty slice", + input: [][]byte{{0x1}, {0x2}}, + idx: 2, + wantErr: true, + }, + { + name: "index out of range in empty slice", + input: [][]byte{}, + idx: 0, + wantErr: true, + }, + { + name: "nil input", + input: nil, + idx: 3, + want: nil, + }, + { + name: "correct copy", + input: [][]byte{{0x1}, {0x2}}, + idx: 1, + want: bytesutil.PadTo([]byte{0x2}, 32), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := bytesutil.SafeCopyRootAtIndex(tt.input, tt.idx) + if (err != nil) != tt.wantErr { + t.Errorf("SafeCopyRootAtIndex() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SafeCopyRootAtIndex() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSafeCopy2dBytes(t *testing.T) { + tests := []struct { + name string + input [][]byte + }{ + { + name: "nil input", + input: nil, + }, + { + name: "correct copy", + input: [][]byte{{0x1}, {0x2}}, + }, + { + name: "empty", + input: [][]byte{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := bytesutil.SafeCopy2dBytes(tt.input); !reflect.DeepEqual(got, tt.input) { + t.Errorf("SafeCopy2dBytes() = %v, want %v", got, tt.input) + } + }) + } +} diff --git a/shared/copyutil/cloners.go b/shared/copyutil/cloners.go index 8b50867ae3..781b1d91eb 100644 --- a/shared/copyutil/cloners.go +++ b/shared/copyutil/cloners.go @@ -17,6 +17,19 @@ func CopyETH1Data(data *ethpb.Eth1Data) *ethpb.Eth1Data { } } +// CopyPendingAttestationSlice copies the provided slice of pending attestation objects. +func CopyPendingAttestationSlice(input []*ethpb.PendingAttestation) []*ethpb.PendingAttestation { + if input == nil { + return nil + } + + res := make([]*ethpb.PendingAttestation, len(input)) + for i := 0; i < len(res); i++ { + res[i] = CopyPendingAttestation(input[i]) + } + return res +} + // CopyPendingAttestation copies the provided pending attestation object. func CopyPendingAttestation(att *ethpb.PendingAttestation) *ethpb.PendingAttestation { if att == nil { @@ -265,7 +278,7 @@ func CopyDeposit(deposit *ethpb.Deposit) *ethpb.Deposit { return nil } return ðpb.Deposit{ - Proof: bytesutil.Copy2dBytes(deposit.Proof), + Proof: bytesutil.SafeCopy2dBytes(deposit.Proof), Data: CopyDepositData(deposit.Data), } } diff --git a/shared/copyutil/cloners_test.go b/shared/copyutil/cloners_test.go index ffc4789a1d..592fb8e694 100644 --- a/shared/copyutil/cloners_test.go +++ b/shared/copyutil/cloners_test.go @@ -278,6 +278,36 @@ func TestCopySyncAggregate(t *testing.T) { assert.NotEmpty(t, got, "Copied sync aggregate has empty fields") } +func TestCopyPendingAttestationSlice(t *testing.T) { + tests := []struct { + name string + input []*ethpb.PendingAttestation + }{ + { + name: "nil", + input: nil, + }, + { + name: "empty", + input: []*ethpb.PendingAttestation{}, + }, + { + name: "correct copy", + input: []*ethpb.PendingAttestation{ + genPendingAttestation(), + genPendingAttestation(), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CopyPendingAttestationSlice(tt.input); !reflect.DeepEqual(got, tt.input) { + t.Errorf("CopyPendingAttestationSlice() = %v, want %v", got, tt.input) + } + }) + } +} + func bytes() []byte { b := make([]byte, 32) _, err := rand.Read(b) diff --git a/shared/testutil/altair.go b/shared/testutil/altair.go index 4e1d516d8f..64c9ba3320 100644 --- a/shared/testutil/altair.go +++ b/shared/testutil/altair.go @@ -206,7 +206,7 @@ func buildGenesisBeaconState(genesisTime uint64, preState state.BeaconStateAltai AggregatePubkey: bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength), } st.NextSyncCommittee = ðpb.SyncCommittee{ - Pubkeys: bytesutil.Copy2dBytes(pubKeys), + Pubkeys: bytesutil.SafeCopy2dBytes(pubKeys), AggregatePubkey: bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength), } diff --git a/shared/trieutil/sparse_merkle.go b/shared/trieutil/sparse_merkle.go index 6c1cba51f0..45b40175b6 100644 --- a/shared/trieutil/sparse_merkle.go +++ b/shared/trieutil/sparse_merkle.go @@ -202,13 +202,13 @@ func VerifyMerkleBranch(root, item []byte, merkleIndex int, proof [][]byte, dept func (m *SparseMerkleTrie) Copy() *SparseMerkleTrie { dstBranches := make([][][]byte, len(m.branches)) for i1, srcB1 := range m.branches { - dstBranches[i1] = bytesutil.Copy2dBytes(srcB1) + dstBranches[i1] = bytesutil.SafeCopy2dBytes(srcB1) } return &SparseMerkleTrie{ depth: m.depth, branches: dstBranches, - originalItems: bytesutil.Copy2dBytes(m.originalItems), + originalItems: bytesutil.SafeCopy2dBytes(m.originalItems), } }