Deduplicate Copy Functions In State Packages (#9437)

* remove duplicated copy functions

* add all unit tests

Co-authored-by: Nishant Das <nishdas93@gmail.com>
This commit is contained in:
Raul Jordan
2021-08-23 11:53:50 -05:00
committed by GitHub
parent 63e0a4de84
commit 0ee203fc2e
13 changed files with 174 additions and 116 deletions

View File

@@ -4,6 +4,7 @@ package bytesutil
import (
"encoding/binary"
"errors"
"fmt"
"math/bits"
"regexp"
@@ -166,6 +167,20 @@ func ToLowInt64(x []byte) int64 {
return int64(binary.LittleEndian.Uint64(x))
}
// SafeCopyRootAtIndex takes a copy of an 32-byte slice in a slice of byte slices. Returns error if index out of range.
func SafeCopyRootAtIndex(input [][]byte, idx uint64) ([]byte, error) {
if input == nil {
return nil, nil
}
if uint64(len(input)) <= idx {
return nil, fmt.Errorf("index %d out of range", idx)
}
item := make([]byte, 32)
copy(item, input[idx])
return item, nil
}
// SafeCopyBytes will copy and return a non-nil byte array, otherwise it returns nil.
func SafeCopyBytes(cp []byte) []byte {
if cp != nil {
@@ -176,8 +191,8 @@ func SafeCopyBytes(cp []byte) []byte {
return nil
}
// Copy2dBytes will copy and return a non-nil 2d byte array, otherwise it returns nil.
func Copy2dBytes(ary [][]byte) [][]byte {
// SafeCopy2dBytes will copy and return a non-nil 2d byte array, otherwise it returns nil.
func SafeCopy2dBytes(ary [][]byte) [][]byte {
if ary != nil {
copied := make([][]byte, len(ary))
for i, a := range ary {

View File

@@ -1,6 +1,7 @@
package bytesutil_test
import (
"reflect"
"testing"
"github.com/prysmaticlabs/prysm/shared/bytesutil"
@@ -397,3 +398,77 @@ func TestIsHex(t *testing.T) {
assert.Equal(t, tt.b, isHex)
}
}
func TestSafeCopyRootAtIndex(t *testing.T) {
tests := []struct {
name string
input [][]byte
idx uint64
want []byte
wantErr bool
}{
{
name: "index out of range in non-empty slice",
input: [][]byte{{0x1}, {0x2}},
idx: 2,
wantErr: true,
},
{
name: "index out of range in empty slice",
input: [][]byte{},
idx: 0,
wantErr: true,
},
{
name: "nil input",
input: nil,
idx: 3,
want: nil,
},
{
name: "correct copy",
input: [][]byte{{0x1}, {0x2}},
idx: 1,
want: bytesutil.PadTo([]byte{0x2}, 32),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := bytesutil.SafeCopyRootAtIndex(tt.input, tt.idx)
if (err != nil) != tt.wantErr {
t.Errorf("SafeCopyRootAtIndex() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("SafeCopyRootAtIndex() got = %v, want %v", got, tt.want)
}
})
}
}
func TestSafeCopy2dBytes(t *testing.T) {
tests := []struct {
name string
input [][]byte
}{
{
name: "nil input",
input: nil,
},
{
name: "correct copy",
input: [][]byte{{0x1}, {0x2}},
},
{
name: "empty",
input: [][]byte{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := bytesutil.SafeCopy2dBytes(tt.input); !reflect.DeepEqual(got, tt.input) {
t.Errorf("SafeCopy2dBytes() = %v, want %v", got, tt.input)
}
})
}
}

View File

@@ -17,6 +17,19 @@ func CopyETH1Data(data *ethpb.Eth1Data) *ethpb.Eth1Data {
}
}
// CopyPendingAttestationSlice copies the provided slice of pending attestation objects.
func CopyPendingAttestationSlice(input []*ethpb.PendingAttestation) []*ethpb.PendingAttestation {
if input == nil {
return nil
}
res := make([]*ethpb.PendingAttestation, len(input))
for i := 0; i < len(res); i++ {
res[i] = CopyPendingAttestation(input[i])
}
return res
}
// CopyPendingAttestation copies the provided pending attestation object.
func CopyPendingAttestation(att *ethpb.PendingAttestation) *ethpb.PendingAttestation {
if att == nil {
@@ -265,7 +278,7 @@ func CopyDeposit(deposit *ethpb.Deposit) *ethpb.Deposit {
return nil
}
return &ethpb.Deposit{
Proof: bytesutil.Copy2dBytes(deposit.Proof),
Proof: bytesutil.SafeCopy2dBytes(deposit.Proof),
Data: CopyDepositData(deposit.Data),
}
}

View File

@@ -278,6 +278,36 @@ func TestCopySyncAggregate(t *testing.T) {
assert.NotEmpty(t, got, "Copied sync aggregate has empty fields")
}
func TestCopyPendingAttestationSlice(t *testing.T) {
tests := []struct {
name string
input []*ethpb.PendingAttestation
}{
{
name: "nil",
input: nil,
},
{
name: "empty",
input: []*ethpb.PendingAttestation{},
},
{
name: "correct copy",
input: []*ethpb.PendingAttestation{
genPendingAttestation(),
genPendingAttestation(),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := CopyPendingAttestationSlice(tt.input); !reflect.DeepEqual(got, tt.input) {
t.Errorf("CopyPendingAttestationSlice() = %v, want %v", got, tt.input)
}
})
}
}
func bytes() []byte {
b := make([]byte, 32)
_, err := rand.Read(b)

View File

@@ -206,7 +206,7 @@ func buildGenesisBeaconState(genesisTime uint64, preState state.BeaconStateAltai
AggregatePubkey: bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength),
}
st.NextSyncCommittee = &ethpb.SyncCommittee{
Pubkeys: bytesutil.Copy2dBytes(pubKeys),
Pubkeys: bytesutil.SafeCopy2dBytes(pubKeys),
AggregatePubkey: bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength),
}

View File

@@ -202,13 +202,13 @@ func VerifyMerkleBranch(root, item []byte, merkleIndex int, proof [][]byte, dept
func (m *SparseMerkleTrie) Copy() *SparseMerkleTrie {
dstBranches := make([][][]byte, len(m.branches))
for i1, srcB1 := range m.branches {
dstBranches[i1] = bytesutil.Copy2dBytes(srcB1)
dstBranches[i1] = bytesutil.SafeCopy2dBytes(srcB1)
}
return &SparseMerkleTrie{
depth: m.depth,
branches: dstBranches,
originalItems: bytesutil.Copy2dBytes(m.originalItems),
originalItems: bytesutil.SafeCopy2dBytes(m.originalItems),
}
}