diff --git a/beacon-chain/state/fieldtrie/BUILD.bazel b/beacon-chain/state/fieldtrie/BUILD.bazel index 180065daab..fcea708041 100644 --- a/beacon-chain/state/fieldtrie/BUILD.bazel +++ b/beacon-chain/state/fieldtrie/BUILD.bazel @@ -12,6 +12,8 @@ go_library( "//beacon-chain/state/stateutil:go_default_library", "//beacon-chain/state/types:go_default_library", "//crypto/hash:go_default_library", + "//encoding/bytesutil:go_default_library", + "//encoding/ssz:go_default_library", "//proto/prysm/v1alpha1:go_default_library", "//runtime/version:go_default_library", "@com_github_pkg_errors//:go_default_library", @@ -26,6 +28,7 @@ go_test( ], embed = [":go_default_library"], deps = [ + "//beacon-chain/state/stateutil:go_default_library", "//beacon-chain/state/types:go_default_library", "//beacon-chain/state/v1:go_default_library", "//config/params:go_default_library", diff --git a/beacon-chain/state/fieldtrie/field_trie.go b/beacon-chain/state/fieldtrie/field_trie.go index f002dc6ed5..4f441668b7 100644 --- a/beacon-chain/state/fieldtrie/field_trie.go +++ b/beacon-chain/state/fieldtrie/field_trie.go @@ -18,6 +18,7 @@ type FieldTrie struct { field types.FieldIndex dataType types.DataType length uint64 + numOfElems int } // NewFieldTrie is the constructor for the field trie data structure. It creates the corresponding @@ -26,18 +27,19 @@ type FieldTrie struct { func NewFieldTrie(field types.FieldIndex, dataType types.DataType, elements interface{}, length uint64) (*FieldTrie, error) { if elements == nil { return &FieldTrie{ - field: field, - dataType: dataType, - reference: stateutil.NewRef(1), - RWMutex: new(sync.RWMutex), - length: length, + field: field, + dataType: dataType, + reference: stateutil.NewRef(1), + RWMutex: new(sync.RWMutex), + length: length, + numOfElems: 0, }, nil } fieldRoots, err := fieldConverters(field, []uint64{}, elements, true) if err != nil { return nil, err } - if err := validateElements(field, elements, length); err != nil { + if err := validateElements(field, dataType, elements, length); err != nil { return nil, err } switch dataType { @@ -53,8 +55,9 @@ func NewFieldTrie(field types.FieldIndex, dataType types.DataType, elements inte reference: stateutil.NewRef(1), RWMutex: new(sync.RWMutex), length: length, + numOfElems: reflect.ValueOf(elements).Len(), }, nil - case types.CompositeArray: + case types.CompositeArray, types.CompressedArray: return &FieldTrie{ fieldLayers: stateutil.ReturnTrieLayerVariable(fieldRoots, length), field: field, @@ -62,6 +65,7 @@ func NewFieldTrie(field types.FieldIndex, dataType types.DataType, elements inte reference: stateutil.NewRef(1), RWMutex: new(sync.RWMutex), length: length, + numOfElems: reflect.ValueOf(elements).Len(), }, nil default: return nil, errors.Errorf("unrecognized data type in field map: %v", reflect.TypeOf(dataType).Name()) @@ -92,13 +96,40 @@ func (f *FieldTrie) RecomputeTrie(indices []uint64, elements interface{}) ([32]b if err != nil { return [32]byte{}, err } + f.numOfElems = reflect.ValueOf(elements).Len() return fieldRoot, nil case types.CompositeArray: fieldRoot, f.fieldLayers, err = stateutil.RecomputeFromLayerVariable(fieldRoots, indices, f.fieldLayers) if err != nil { return [32]byte{}, err } + f.numOfElems = reflect.ValueOf(elements).Len() return stateutil.AddInMixin(fieldRoot, uint64(len(f.fieldLayers[0]))) + case types.CompressedArray: + numOfElems, err := f.field.ElemsInChunk() + if err != nil { + return [32]byte{}, err + } + // We remove the duplicates here in order to prevent + // duplicated insertions into the trie. + newIndices := []uint64{} + indexExists := make(map[uint64]bool) + newRoots := make([][32]byte, 0, len(fieldRoots)/int(numOfElems)) + for i, idx := range indices { + startIdx := idx / numOfElems + if indexExists[startIdx] { + continue + } + newIndices = append(newIndices, startIdx) + indexExists[startIdx] = true + newRoots = append(newRoots, fieldRoots[i]) + } + fieldRoot, f.fieldLayers, err = stateutil.RecomputeFromLayerVariable(newRoots, newIndices, f.fieldLayers) + if err != nil { + return [32]byte{}, err + } + f.numOfElems = reflect.ValueOf(elements).Len() + return stateutil.AddInMixin(fieldRoot, uint64(f.numOfElems)) default: return [32]byte{}, errors.Errorf("unrecognized data type in field map: %v", reflect.TypeOf(f.dataType).Name()) } @@ -109,11 +140,12 @@ func (f *FieldTrie) RecomputeTrie(indices []uint64, elements interface{}) ([32]b func (f *FieldTrie) CopyTrie() *FieldTrie { if f.fieldLayers == nil { return &FieldTrie{ - field: f.field, - dataType: f.dataType, - reference: stateutil.NewRef(1), - RWMutex: new(sync.RWMutex), - length: f.length, + field: f.field, + dataType: f.dataType, + reference: stateutil.NewRef(1), + RWMutex: new(sync.RWMutex), + length: f.length, + numOfElems: f.numOfElems, } } dstFieldTrie := make([][]*[32]byte, len(f.fieldLayers)) @@ -128,6 +160,7 @@ func (f *FieldTrie) CopyTrie() *FieldTrie { reference: stateutil.NewRef(1), RWMutex: new(sync.RWMutex), length: f.length, + numOfElems: f.numOfElems, } } @@ -139,6 +172,9 @@ func (f *FieldTrie) TrieRoot() ([32]byte, error) { case types.CompositeArray: trieRoot := *f.fieldLayers[len(f.fieldLayers)-1][0] return stateutil.AddInMixin(trieRoot, uint64(len(f.fieldLayers[0]))) + case types.CompressedArray: + trieRoot := *f.fieldLayers[len(f.fieldLayers)-1][0] + return stateutil.AddInMixin(trieRoot, uint64(f.numOfElems)) default: return [32]byte{}, errors.Errorf("unrecognized data type in field map: %v", reflect.TypeOf(f.dataType).Name()) } diff --git a/beacon-chain/state/fieldtrie/field_trie_helpers.go b/beacon-chain/state/fieldtrie/field_trie_helpers.go index 77aa4ac167..5a7c6b687e 100644 --- a/beacon-chain/state/fieldtrie/field_trie_helpers.go +++ b/beacon-chain/state/fieldtrie/field_trie_helpers.go @@ -1,6 +1,7 @@ package fieldtrie import ( + "encoding/binary" "fmt" "reflect" @@ -8,20 +9,37 @@ import ( "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" "github.com/prysmaticlabs/prysm/beacon-chain/state/types" "github.com/prysmaticlabs/prysm/crypto/hash" + "github.com/prysmaticlabs/prysm/encoding/bytesutil" + "github.com/prysmaticlabs/prysm/encoding/ssz" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" "github.com/prysmaticlabs/prysm/runtime/version" ) func (f *FieldTrie) validateIndices(idxs []uint64) error { + length := f.length + if f.dataType == types.CompressedArray { + comLength, err := f.field.ElemsInChunk() + if err != nil { + return err + } + length *= comLength + } for _, idx := range idxs { - if idx >= f.length { - return errors.Errorf("invalid index for field %s: %d >= length %d", f.field.String(version.Phase0), idx, f.length) + if idx >= length { + return errors.Errorf("invalid index for field %s: %d >= length %d", f.field.String(version.Phase0), idx, length) } } return nil } -func validateElements(field types.FieldIndex, elements interface{}, length uint64) error { +func validateElements(field types.FieldIndex, dataType types.DataType, elements interface{}, length uint64) error { + if dataType == types.CompressedArray { + comLength, err := field.ElemsInChunk() + if err != nil { + return err + } + length *= comLength + } val := reflect.ValueOf(elements) if val.Len() > int(length) { return errors.Errorf("elements length is larger than expected for field %s: %d > %d", field.String(version.Phase0), val.Len(), length) @@ -38,21 +56,21 @@ func fieldConverters(field types.FieldIndex, indices []uint64, elements interfac return nil, errors.Errorf("Wanted type of %v but got %v", reflect.TypeOf([][]byte{}).Name(), reflect.TypeOf(elements).Name()) } - return stateutil.HandleByteArrays(val, indices, convertAll) + return handleByteArrays(val, indices, convertAll) case types.Eth1DataVotes: val, ok := elements.([]*ethpb.Eth1Data) if !ok { return nil, errors.Errorf("Wanted type of %v but got %v", reflect.TypeOf([]*ethpb.Eth1Data{}).Name(), reflect.TypeOf(elements).Name()) } - return HandleEth1DataSlice(val, indices, convertAll) + return handleEth1DataSlice(val, indices, convertAll) case types.Validators: val, ok := elements.([]*ethpb.Validator) if !ok { return nil, errors.Errorf("Wanted type of %v but got %v", reflect.TypeOf([]*ethpb.Validator{}).Name(), reflect.TypeOf(elements).Name()) } - return stateutil.HandleValidatorSlice(val, indices, convertAll) + return handleValidatorSlice(val, indices, convertAll) case types.PreviousEpochAttestations, types.CurrentEpochAttestations: val, ok := elements.([]*ethpb.PendingAttestation) if !ok { @@ -60,13 +78,87 @@ func fieldConverters(field types.FieldIndex, indices []uint64, elements interfac reflect.TypeOf([]*ethpb.PendingAttestation{}).Name(), reflect.TypeOf(elements).Name()) } return handlePendingAttestation(val, indices, convertAll) + case types.Balances: + val, ok := elements.([]uint64) + if !ok { + return nil, errors.Errorf("Wanted type of %v but got %v", + reflect.TypeOf([]uint64{}).Name(), reflect.TypeOf(elements).Name()) + } + return handleBalanceSlice(val, indices, convertAll) default: return [][32]byte{}, errors.Errorf("got unsupported type of %v", reflect.TypeOf(elements).Name()) } } -// HandleEth1DataSlice processes a list of eth1data and indices into the appropriate roots. -func HandleEth1DataSlice(val []*ethpb.Eth1Data, indices []uint64, convertAll bool) ([][32]byte, error) { +// handleByteArrays computes and returns byte arrays in a slice of root format. +func handleByteArrays(val [][]byte, indices []uint64, convertAll bool) ([][32]byte, error) { + length := len(indices) + if convertAll { + length = len(val) + } + roots := make([][32]byte, 0, length) + rootCreator := func(input []byte) { + newRoot := bytesutil.ToBytes32(input) + roots = append(roots, newRoot) + } + if convertAll { + for i := range val { + rootCreator(val[i]) + } + return roots, nil + } + if len(val) > 0 { + for _, idx := range indices { + if idx > uint64(len(val))-1 { + return nil, fmt.Errorf("index %d greater than number of byte arrays %d", idx, len(val)) + } + rootCreator(val[idx]) + } + } + return roots, nil +} + +// handleValidatorSlice returns the validator indices in a slice of root format. +func handleValidatorSlice(val []*ethpb.Validator, indices []uint64, convertAll bool) ([][32]byte, error) { + length := len(indices) + if convertAll { + length = len(val) + } + roots := make([][32]byte, 0, length) + hasher := hash.CustomSHA256Hasher() + rootCreator := func(input *ethpb.Validator) error { + newRoot, err := stateutil.ValidatorRootWithHasher(hasher, input) + if err != nil { + return err + } + roots = append(roots, newRoot) + return nil + } + if convertAll { + for i := range val { + err := rootCreator(val[i]) + if err != nil { + return nil, err + } + } + return roots, nil + } + if len(val) > 0 { + for _, idx := range indices { + if idx > uint64(len(val))-1 { + return nil, fmt.Errorf("index %d greater than number of validators %d", idx, len(val)) + } + err := rootCreator(val[idx]) + if err != nil { + return nil, err + } + } + } + return roots, nil +} + +// handleEth1DataSlice processes a list of eth1data and indices into the appropriate roots. +func handleEth1DataSlice(val []*ethpb.Eth1Data, indices []uint64, convertAll bool) ([][32]byte, error) { length := len(indices) if convertAll { length = len(val) @@ -141,3 +233,48 @@ func handlePendingAttestation(val []*ethpb.PendingAttestation, indices []uint64, } return roots, nil } + +func handleBalanceSlice(val []uint64, indices []uint64, convertAll bool) ([][32]byte, error) { + if convertAll { + balancesMarshaling := make([][]byte, 0) + for _, b := range val { + balanceBuf := make([]byte, 8) + binary.LittleEndian.PutUint64(balanceBuf, b) + balancesMarshaling = append(balancesMarshaling, balanceBuf) + } + balancesChunks, err := ssz.PackByChunk(balancesMarshaling) + if err != nil { + return [][32]byte{}, errors.Wrap(err, "could not pack balances into chunks") + } + return balancesChunks, nil + } + if len(val) > 0 { + numOfElems, err := types.Balances.ElemsInChunk() + if err != nil { + return nil, err + } + roots := [][32]byte{} + for _, idx := range indices { + // We split the indexes into their relevant groups. Balances + // are compressed according to 4 values -> 1 chunk. + startIdx := idx / numOfElems + startGroup := startIdx * numOfElems + chunk := [32]byte{} + sizeOfElem := len(chunk) / int(numOfElems) + for i, j := 0, startGroup; j < startGroup+numOfElems; i, j = i+sizeOfElem, j+1 { + wantedVal := uint64(0) + // We are adding chunks in sets of 4, if the set is at the edge of the array + // then you will need to zero out the rest of the chunk. Ex : 41 indexes, + // so 41 % 4 = 1 . There are 3 indexes, which do not exist yet but we + // have to add in as a root. These 3 indexes are then given a 'zero' value. + if int(j) < len(val) { + wantedVal = val[j] + } + binary.LittleEndian.PutUint64(chunk[i:i+sizeOfElem], wantedVal) + } + roots = append(roots, chunk) + } + return roots, nil + } + return [][32]byte{}, nil +} diff --git a/beacon-chain/state/fieldtrie/helpers_test.go b/beacon-chain/state/fieldtrie/helpers_test.go index ec9cf5aae2..54666b872a 100644 --- a/beacon-chain/state/fieldtrie/helpers_test.go +++ b/beacon-chain/state/fieldtrie/helpers_test.go @@ -1,8 +1,13 @@ package fieldtrie import ( + "encoding/binary" + "sync" "testing" + "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" + "github.com/prysmaticlabs/prysm/beacon-chain/state/types" + "github.com/prysmaticlabs/prysm/config/params" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" "github.com/prysmaticlabs/prysm/testing/assert" ) @@ -17,7 +22,64 @@ func Test_handlePendingAttestation_OutOfRange(t *testing.T) { func Test_handleEth1DataSlice_OutOfRange(t *testing.T) { items := make([]*ethpb.Eth1Data, 1) indices := []uint64{3} - _, err := HandleEth1DataSlice(items, indices, false) + _, err := handleEth1DataSlice(items, indices, false) assert.ErrorContains(t, "index 3 greater than number of items in eth1 data slice 1", err) } + +func Test_handleValidatorSlice_OutOfRange(t *testing.T) { + vals := make([]*ethpb.Validator, 1) + indices := []uint64{3} + _, err := handleValidatorSlice(vals, indices, false) + assert.ErrorContains(t, "index 3 greater than number of validators 1", err) +} + +func TestBalancesSlice_CorrectRoots_All(t *testing.T) { + balances := []uint64{5, 2929, 34, 1291, 354305} + roots, err := handleBalanceSlice(balances, []uint64{}, true) + assert.NoError(t, err) + + root1 := [32]byte{} + binary.LittleEndian.PutUint64(root1[:8], balances[0]) + binary.LittleEndian.PutUint64(root1[8:16], balances[1]) + binary.LittleEndian.PutUint64(root1[16:24], balances[2]) + binary.LittleEndian.PutUint64(root1[24:32], balances[3]) + + root2 := [32]byte{} + binary.LittleEndian.PutUint64(root2[:8], balances[4]) + + assert.DeepEqual(t, roots, [][32]byte{root1, root2}) +} + +func TestBalancesSlice_CorrectRoots_Some(t *testing.T) { + balances := []uint64{5, 2929, 34, 1291, 354305} + roots, err := handleBalanceSlice(balances, []uint64{2, 3}, false) + assert.NoError(t, err) + + root1 := [32]byte{} + binary.LittleEndian.PutUint64(root1[:8], balances[0]) + binary.LittleEndian.PutUint64(root1[8:16], balances[1]) + binary.LittleEndian.PutUint64(root1[16:24], balances[2]) + binary.LittleEndian.PutUint64(root1[24:32], balances[3]) + + // Returns root for each indice(even if duplicated) + assert.DeepEqual(t, roots, [][32]byte{root1, root1}) +} + +func TestValidateIndices_CompressedField(t *testing.T) { + fakeTrie := &FieldTrie{ + RWMutex: new(sync.RWMutex), + reference: stateutil.NewRef(0), + fieldLayers: nil, + field: types.Balances, + dataType: types.CompressedArray, + length: params.BeaconConfig().ValidatorRegistryLimit / 4, + numOfElems: 0, + } + goodIdx := params.BeaconConfig().ValidatorRegistryLimit - 1 + assert.NoError(t, fakeTrie.validateIndices([]uint64{goodIdx})) + + badIdx := goodIdx + 1 + assert.ErrorContains(t, "invalid index for field balances", fakeTrie.validateIndices([]uint64{badIdx})) + +} diff --git a/beacon-chain/state/stateutil/BUILD.bazel b/beacon-chain/state/stateutil/BUILD.bazel index 8d9d2cc194..8b83cb2ca9 100644 --- a/beacon-chain/state/stateutil/BUILD.bazel +++ b/beacon-chain/state/stateutil/BUILD.bazel @@ -3,7 +3,6 @@ load("@prysm//tools/go:def.bzl", "go_library", "go_test") go_library( name = "go_default_library", srcs = [ - "array_root.go", "block_header_root.go", "eth1_root.go", "participation_bit_root.go", @@ -49,7 +48,6 @@ go_test( "state_root_test.go", "stateutil_test.go", "trie_helpers_test.go", - "validator_root_test.go", ], embed = [":go_default_library"], deps = [ diff --git a/beacon-chain/state/stateutil/array_root.go b/beacon-chain/state/stateutil/array_root.go deleted file mode 100644 index eb07514279..0000000000 --- a/beacon-chain/state/stateutil/array_root.go +++ /dev/null @@ -1,35 +0,0 @@ -package stateutil - -import ( - "fmt" - - "github.com/prysmaticlabs/prysm/encoding/bytesutil" -) - -// HandleByteArrays computes and returns byte arrays in a slice of root format. -func HandleByteArrays(val [][]byte, indices []uint64, convertAll bool) ([][32]byte, error) { - length := len(indices) - if convertAll { - length = len(val) - } - roots := make([][32]byte, 0, length) - rootCreator := func(input []byte) { - newRoot := bytesutil.ToBytes32(input) - roots = append(roots, newRoot) - } - if convertAll { - for i := range val { - rootCreator(val[i]) - } - return roots, nil - } - if len(val) > 0 { - for _, idx := range indices { - if idx > uint64(len(val))-1 { - return nil, fmt.Errorf("index %d greater than number of byte arrays %d", idx, len(val)) - } - rootCreator(val[idx]) - } - } - return roots, nil -} diff --git a/beacon-chain/state/stateutil/validator_root.go b/beacon-chain/state/stateutil/validator_root.go index 7564572447..e41fc0d6af 100644 --- a/beacon-chain/state/stateutil/validator_root.go +++ b/beacon-chain/state/stateutil/validator_root.go @@ -2,7 +2,6 @@ package stateutil import ( "encoding/binary" - "fmt" "github.com/pkg/errors" "github.com/prysmaticlabs/prysm/config/params" @@ -127,42 +126,3 @@ func ValidatorEncKey(validator *ethpb.Validator) []byte { return enc } - -// HandleValidatorSlice returns the validator indices in a slice of root format. -func HandleValidatorSlice(val []*ethpb.Validator, indices []uint64, convertAll bool) ([][32]byte, error) { - length := len(indices) - if convertAll { - length = len(val) - } - roots := make([][32]byte, 0, length) - hasher := hash.CustomSHA256Hasher() - rootCreator := func(input *ethpb.Validator) error { - newRoot, err := ValidatorRootWithHasher(hasher, input) - if err != nil { - return err - } - roots = append(roots, newRoot) - return nil - } - if convertAll { - for i := range val { - err := rootCreator(val[i]) - if err != nil { - return nil, err - } - } - return roots, nil - } - if len(val) > 0 { - for _, idx := range indices { - if idx > uint64(len(val))-1 { - return nil, fmt.Errorf("index %d greater than number of validators %d", idx, len(val)) - } - err := rootCreator(val[idx]) - if err != nil { - return nil, err - } - } - } - return roots, nil -} diff --git a/beacon-chain/state/stateutil/validator_root_test.go b/beacon-chain/state/stateutil/validator_root_test.go deleted file mode 100644 index 9526a062f2..0000000000 --- a/beacon-chain/state/stateutil/validator_root_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package stateutil - -import ( - "testing" - - ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" - "github.com/prysmaticlabs/prysm/testing/assert" -) - -func Test_handleValidatorSlice_OutOfRange(t *testing.T) { - vals := make([]*ethpb.Validator, 1) - indices := []uint64{3} - _, err := HandleValidatorSlice(vals, indices, false) - assert.ErrorContains(t, "index 3 greater than number of validators 1", err) -} diff --git a/beacon-chain/state/types/BUILD.bazel b/beacon-chain/state/types/BUILD.bazel index 43ec2319b1..23a0642300 100644 --- a/beacon-chain/state/types/BUILD.bazel +++ b/beacon-chain/state/types/BUILD.bazel @@ -5,7 +5,10 @@ go_library( srcs = ["types.go"], importpath = "github.com/prysmaticlabs/prysm/beacon-chain/state/types", visibility = ["//beacon-chain:__subpackages__"], - deps = ["//runtime/version:go_default_library"], + deps = [ + "//runtime/version:go_default_library", + "@com_github_pkg_errors//:go_default_library", + ], ) go_test( diff --git a/beacon-chain/state/types/types.go b/beacon-chain/state/types/types.go index 34feb537e3..53c0253471 100644 --- a/beacon-chain/state/types/types.go +++ b/beacon-chain/state/types/types.go @@ -1,6 +1,7 @@ package types import ( + "github.com/pkg/errors" "github.com/prysmaticlabs/prysm/runtime/version" ) @@ -18,6 +19,10 @@ const ( // CompositeArray represents a variable length array with // a non primitive type. CompositeArray + // CompressedArray represents a variable length array which + // can pack multiple elements into a leaf of the underlying + // trie. + CompressedArray ) // String returns the name of the field index. @@ -84,6 +89,17 @@ func (f FieldIndex) String(stateVersion int) string { } } +// ElemsInChunk returns the number of elements in the chunk (number of +// elements that are able to be packed). +func (f FieldIndex) ElemsInChunk() (uint64, error) { + switch f { + case Balances: + return 4, nil + default: + return 0, errors.Errorf("field %d doesn't support element compression", f) + } +} + // Below we define a set of useful enum values for the field // indices of the beacon state. For example, genesisTime is the // 0th field of the beacon state. This is helpful when we are diff --git a/beacon-chain/state/v1/BUILD.bazel b/beacon-chain/state/v1/BUILD.bazel index c0360749df..ca1220a9c5 100644 --- a/beacon-chain/state/v1/BUILD.bazel +++ b/beacon-chain/state/v1/BUILD.bazel @@ -87,6 +87,7 @@ go_test( "//beacon-chain/state:go_default_library", "//beacon-chain/state/stateutil:go_default_library", "//beacon-chain/state/types:go_default_library", + "//config/features:go_default_library", "//config/params:go_default_library", "//encoding/bytesutil:go_default_library", "//proto/prysm/v1alpha1:go_default_library", diff --git a/beacon-chain/state/v1/setters_misc.go b/beacon-chain/state/v1/setters_misc.go index f3cef51867..e6b89626b3 100644 --- a/beacon-chain/state/v1/setters_misc.go +++ b/beacon-chain/state/v1/setters_misc.go @@ -5,6 +5,7 @@ import ( types "github.com/prysmaticlabs/eth2-types" "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" stateTypes "github.com/prysmaticlabs/prysm/beacon-chain/state/types" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/crypto/hash" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" "google.golang.org/protobuf/proto" @@ -172,6 +173,10 @@ func (b *BeaconState) addDirtyIndices(index stateTypes.FieldIndex, indices []uin if b.rebuildTrie[index] { return } + // Exit early if balance trie computation isn't enabled. + if !features.Get().EnableBalanceTrieComputation && index == balances { + return + } totalIndicesLen := len(b.dirtyIndices[index]) + len(indices) if totalIndicesLen > indicesLimit { b.rebuildTrie[index] = true diff --git a/beacon-chain/state/v1/setters_validator.go b/beacon-chain/state/v1/setters_validator.go index e60b179e43..598c278408 100644 --- a/beacon-chain/state/v1/setters_validator.go +++ b/beacon-chain/state/v1/setters_validator.go @@ -103,6 +103,7 @@ func (b *BeaconState) SetBalances(val []uint64) error { b.state.Balances = val b.markFieldAsDirty(balances) + b.rebuildTrie[balances] = true return nil } @@ -128,6 +129,7 @@ func (b *BeaconState) UpdateBalancesAtIndex(idx types.ValidatorIndex, val uint64 bals[idx] = val b.state.Balances = bals b.markFieldAsDirty(balances) + b.addDirtyIndices(balances, []uint64{uint64(idx)}) return nil } @@ -219,6 +221,8 @@ func (b *BeaconState) AppendBalance(bal uint64) error { } b.state.Balances = append(bals, bal) + balIdx := len(b.state.Balances) - 1 b.markFieldAsDirty(balances) + b.addDirtyIndices(balances, []uint64{uint64(balIdx)}) return nil } diff --git a/beacon-chain/state/v1/state_test.go b/beacon-chain/state/v1/state_test.go index 4d8c90396f..20379ca0db 100644 --- a/beacon-chain/state/v1/state_test.go +++ b/beacon-chain/state/v1/state_test.go @@ -1,10 +1,13 @@ package v1 import ( + "context" "strconv" "sync" "testing" + types "github.com/prysmaticlabs/eth2-types" + "github.com/prysmaticlabs/go-bitfield" "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" "github.com/prysmaticlabs/prysm/config/params" "github.com/prysmaticlabs/prysm/encoding/bytesutil" @@ -104,3 +107,90 @@ func TestStateTrie_IsNil(t *testing.T) { nonNilState := &BeaconState{state: ðpb.BeaconState{}} assert.Equal(t, false, nonNilState.IsNil()) } + +func TestBeaconState_AppendBalanceWithTrie(t *testing.T) { + count := uint64(100) + vals := make([]*ethpb.Validator, 0, count) + bals := make([]uint64, 0, count) + for i := uint64(1); i < count; i++ { + someRoot := [32]byte{} + someKey := [48]byte{} + copy(someRoot[:], strconv.Itoa(int(i))) + copy(someKey[:], strconv.Itoa(int(i))) + vals = append(vals, ðpb.Validator{ + PublicKey: someKey[:], + WithdrawalCredentials: someRoot[:], + EffectiveBalance: params.BeaconConfig().MaxEffectiveBalance, + Slashed: false, + ActivationEligibilityEpoch: 1, + ActivationEpoch: 1, + ExitEpoch: 1, + WithdrawableEpoch: 1, + }) + bals = append(bals, params.BeaconConfig().MaxEffectiveBalance) + } + zeroHash := params.BeaconConfig().ZeroHash + mockblockRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot) + for i := 0; i < len(mockblockRoots); i++ { + mockblockRoots[i] = zeroHash[:] + } + + mockstateRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot) + for i := 0; i < len(mockstateRoots); i++ { + mockstateRoots[i] = zeroHash[:] + } + mockrandaoMixes := make([][]byte, params.BeaconConfig().EpochsPerHistoricalVector) + for i := 0; i < len(mockrandaoMixes); i++ { + mockrandaoMixes[i] = zeroHash[:] + } + var pubKeys [][]byte + for i := uint64(0); i < params.BeaconConfig().SyncCommitteeSize; i++ { + pubKeys = append(pubKeys, bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength)) + } + st, err := InitializeFromProto(ðpb.BeaconState{ + Slot: 1, + GenesisValidatorsRoot: make([]byte, 32), + Fork: ðpb.Fork{ + PreviousVersion: make([]byte, 4), + CurrentVersion: make([]byte, 4), + Epoch: 0, + }, + LatestBlockHeader: ðpb.BeaconBlockHeader{ + ParentRoot: make([]byte, 32), + StateRoot: make([]byte, 32), + BodyRoot: make([]byte, 32), + }, + Validators: vals, + Balances: bals, + Eth1Data: ðpb.Eth1Data{ + DepositRoot: make([]byte, 32), + BlockHash: make([]byte, 32), + }, + BlockRoots: mockblockRoots, + StateRoots: mockstateRoots, + RandaoMixes: mockrandaoMixes, + JustificationBits: bitfield.NewBitvector4(), + PreviousJustifiedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)}, + CurrentJustifiedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)}, + FinalizedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)}, + Slashings: make([]uint64, params.BeaconConfig().EpochsPerSlashingsVector), + }) + assert.NoError(t, err) + _, err = st.HashTreeRoot(context.Background()) + assert.NoError(t, err) + + for i := 0; i < 100; i++ { + if i%2 == 0 { + assert.NoError(t, st.UpdateBalancesAtIndex(types.ValidatorIndex(i), 1000)) + } + if i%3 == 0 { + assert.NoError(t, st.AppendBalance(1000)) + } + } + _, err = st.HashTreeRoot(context.Background()) + assert.NoError(t, err) + newRt := bytesutil.ToBytes32(st.merkleLayers[0][balances]) + wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.state.Balances) + assert.NoError(t, err) + assert.Equal(t, wantedRt, newRt, "state roots are unequal") +} diff --git a/beacon-chain/state/v1/state_trie.go b/beacon-chain/state/v1/state_trie.go index 98731b0b5a..2c862ca034 100644 --- a/beacon-chain/state/v1/state_trie.go +++ b/beacon-chain/state/v1/state_trie.go @@ -12,6 +12,7 @@ import ( "github.com/prysmaticlabs/prysm/beacon-chain/state/fieldtrie" "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" "github.com/prysmaticlabs/prysm/beacon-chain/state/types" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/config/params" "github.com/prysmaticlabs/prysm/container/slice" "github.com/prysmaticlabs/prysm/crypto/hash" @@ -319,6 +320,20 @@ func (b *BeaconState) rootSelector(ctx context.Context, field types.FieldIndex) } return b.recomputeFieldTrie(validators, b.state.Validators) case balances: + if features.Get().EnableBalanceTrieComputation { + if b.rebuildTrie[field] { + maxBalCap := params.BeaconConfig().ValidatorRegistryLimit + elemSize := uint64(8) + balLimit := (maxBalCap*elemSize + 31) / 32 + err := b.resetFieldTrie(field, b.state.Balances, balLimit) + if err != nil { + return [32]byte{}, err + } + delete(b.rebuildTrie, field) + return b.stateFieldLeaves[field].TrieRoot() + } + return b.recomputeFieldTrie(balances, b.state.Balances) + } return stateutil.Uint64ListRootWithRegistryLimit(b.state.Balances) case randaoMixes: if b.rebuildTrie[field] { diff --git a/beacon-chain/state/v1/state_trie_test.go b/beacon-chain/state/v1/state_trie_test.go index 542eb1a163..d3a155c4f4 100644 --- a/beacon-chain/state/v1/state_trie_test.go +++ b/beacon-chain/state/v1/state_trie_test.go @@ -7,6 +7,7 @@ import ( "github.com/prysmaticlabs/prysm/beacon-chain/state" v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/config/params" "github.com/prysmaticlabs/prysm/encoding/bytesutil" eth "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" @@ -16,6 +17,12 @@ import ( "github.com/prysmaticlabs/prysm/testing/util" ) +func TestMain(m *testing.M) { + resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true}) + defer resetCfg() + m.Run() +} + func TestInitializeFromProto(t *testing.T) { testState, _ := util.DeterministicGenesisState(t, 64) pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe()) diff --git a/beacon-chain/state/v1/types.go b/beacon-chain/state/v1/types.go index efdfefb4b1..d9ba4b1e29 100644 --- a/beacon-chain/state/v1/types.go +++ b/beacon-chain/state/v1/types.go @@ -17,7 +17,6 @@ var _ state.BeaconState = (*BeaconState)(nil) func init() { fieldMap = make(map[types.FieldIndex]types.DataType, params.BeaconConfig().BeaconStateFieldCount) - // Initialize the fixed sized arrays. fieldMap[types.BlockRoots] = types.BasicArray fieldMap[types.StateRoots] = types.BasicArray @@ -28,6 +27,7 @@ func init() { fieldMap[types.Validators] = types.CompositeArray fieldMap[types.PreviousEpochAttestations] = types.CompositeArray fieldMap[types.CurrentEpochAttestations] = types.CompositeArray + fieldMap[types.Balances] = types.CompressedArray } // fieldMap keeps track of each field diff --git a/beacon-chain/state/v2/BUILD.bazel b/beacon-chain/state/v2/BUILD.bazel index 5c2995845f..063fbef75b 100644 --- a/beacon-chain/state/v2/BUILD.bazel +++ b/beacon-chain/state/v2/BUILD.bazel @@ -79,11 +79,13 @@ go_test( "//beacon-chain/state/stateutil:go_default_library", "//beacon-chain/state/types:go_default_library", "//beacon-chain/state/v1:go_default_library", + "//config/features:go_default_library", "//config/params:go_default_library", "//encoding/bytesutil:go_default_library", "//proto/prysm/v1alpha1:go_default_library", "//testing/assert:go_default_library", "//testing/require:go_default_library", "@com_github_prysmaticlabs_eth2_types//:go_default_library", + "@com_github_prysmaticlabs_go_bitfield//:go_default_library", ], ) diff --git a/beacon-chain/state/v2/setters_misc.go b/beacon-chain/state/v2/setters_misc.go index 5395bee2d4..87da6bcd30 100644 --- a/beacon-chain/state/v2/setters_misc.go +++ b/beacon-chain/state/v2/setters_misc.go @@ -5,6 +5,7 @@ import ( types "github.com/prysmaticlabs/eth2-types" "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" stateTypes "github.com/prysmaticlabs/prysm/beacon-chain/state/types" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/crypto/hash" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" "google.golang.org/protobuf/proto" @@ -171,6 +172,10 @@ func (b *BeaconState) addDirtyIndices(index stateTypes.FieldIndex, indices []uin if b.rebuildTrie[index] { return } + // Exit early if balance trie computation isn't enabled. + if !features.Get().EnableBalanceTrieComputation && index == balances { + return + } totalIndicesLen := len(b.dirtyIndices[index]) + len(indices) if totalIndicesLen > indicesLimit { b.rebuildTrie[index] = true diff --git a/beacon-chain/state/v2/setters_test.go b/beacon-chain/state/v2/setters_test.go index f05031c32d..82e692000b 100644 --- a/beacon-chain/state/v2/setters_test.go +++ b/beacon-chain/state/v2/setters_test.go @@ -2,10 +2,15 @@ package v2 import ( "context" + "strconv" "testing" + types "github.com/prysmaticlabs/eth2-types" + "github.com/prysmaticlabs/go-bitfield" + "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" stateTypes "github.com/prysmaticlabs/prysm/beacon-chain/state/types" "github.com/prysmaticlabs/prysm/config/params" + "github.com/prysmaticlabs/prysm/encoding/bytesutil" eth "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" "github.com/prysmaticlabs/prysm/testing/assert" @@ -57,3 +62,100 @@ func TestAppendBeyondIndicesLimit(t *testing.T) { assert.Equal(t, true, st.rebuildTrie[validators]) assert.Equal(t, len(st.dirtyIndices[validators]), 0) } + +func TestBeaconState_AppendBalanceWithTrie(t *testing.T) { + count := uint64(100) + vals := make([]*ethpb.Validator, 0, count) + bals := make([]uint64, 0, count) + for i := uint64(1); i < count; i++ { + someRoot := [32]byte{} + someKey := [48]byte{} + copy(someRoot[:], strconv.Itoa(int(i))) + copy(someKey[:], strconv.Itoa(int(i))) + vals = append(vals, ðpb.Validator{ + PublicKey: someKey[:], + WithdrawalCredentials: someRoot[:], + EffectiveBalance: params.BeaconConfig().MaxEffectiveBalance, + Slashed: false, + ActivationEligibilityEpoch: 1, + ActivationEpoch: 1, + ExitEpoch: 1, + WithdrawableEpoch: 1, + }) + bals = append(bals, params.BeaconConfig().MaxEffectiveBalance) + } + zeroHash := params.BeaconConfig().ZeroHash + mockblockRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot) + for i := 0; i < len(mockblockRoots); i++ { + mockblockRoots[i] = zeroHash[:] + } + + mockstateRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot) + for i := 0; i < len(mockstateRoots); i++ { + mockstateRoots[i] = zeroHash[:] + } + mockrandaoMixes := make([][]byte, params.BeaconConfig().EpochsPerHistoricalVector) + for i := 0; i < len(mockrandaoMixes); i++ { + mockrandaoMixes[i] = zeroHash[:] + } + var pubKeys [][]byte + for i := uint64(0); i < params.BeaconConfig().SyncCommitteeSize; i++ { + pubKeys = append(pubKeys, bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength)) + } + st, err := InitializeFromProto(ðpb.BeaconStateAltair{ + Slot: 1, + GenesisValidatorsRoot: make([]byte, 32), + Fork: ðpb.Fork{ + PreviousVersion: make([]byte, 4), + CurrentVersion: make([]byte, 4), + Epoch: 0, + }, + LatestBlockHeader: ðpb.BeaconBlockHeader{ + ParentRoot: make([]byte, 32), + StateRoot: make([]byte, 32), + BodyRoot: make([]byte, 32), + }, + CurrentEpochParticipation: []byte{}, + PreviousEpochParticipation: []byte{}, + Validators: vals, + Balances: bals, + Eth1Data: ð.Eth1Data{ + DepositRoot: make([]byte, 32), + BlockHash: make([]byte, 32), + }, + BlockRoots: mockblockRoots, + StateRoots: mockstateRoots, + RandaoMixes: mockrandaoMixes, + JustificationBits: bitfield.NewBitvector4(), + PreviousJustifiedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)}, + CurrentJustifiedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)}, + FinalizedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)}, + Slashings: make([]uint64, params.BeaconConfig().EpochsPerSlashingsVector), + CurrentSyncCommittee: ðpb.SyncCommittee{ + Pubkeys: pubKeys, + AggregatePubkey: make([]byte, 48), + }, + NextSyncCommittee: ðpb.SyncCommittee{ + Pubkeys: pubKeys, + AggregatePubkey: make([]byte, 48), + }, + }) + assert.NoError(t, err) + _, err = st.HashTreeRoot(context.Background()) + assert.NoError(t, err) + + for i := 0; i < 100; i++ { + if i%2 == 0 { + assert.NoError(t, st.UpdateBalancesAtIndex(types.ValidatorIndex(i), 1000)) + } + if i%3 == 0 { + assert.NoError(t, st.AppendBalance(1000)) + } + } + _, err = st.HashTreeRoot(context.Background()) + assert.NoError(t, err) + newRt := bytesutil.ToBytes32(st.merkleLayers[0][balances]) + wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.state.Balances) + assert.NoError(t, err) + assert.Equal(t, wantedRt, newRt, "state roots are unequal") +} diff --git a/beacon-chain/state/v2/setters_validator.go b/beacon-chain/state/v2/setters_validator.go index f359a9d7e2..3299c09022 100644 --- a/beacon-chain/state/v2/setters_validator.go +++ b/beacon-chain/state/v2/setters_validator.go @@ -102,6 +102,7 @@ func (b *BeaconState) SetBalances(val []uint64) error { b.sharedFieldReferences[balances] = stateutil.NewRef(1) b.state.Balances = val + b.rebuildTrie[balances] = true b.markFieldAsDirty(balances) return nil } @@ -128,6 +129,7 @@ func (b *BeaconState) UpdateBalancesAtIndex(idx types.ValidatorIndex, val uint64 bals[idx] = val b.state.Balances = bals b.markFieldAsDirty(balances) + b.addDirtyIndices(balances, []uint64{uint64(idx)}) return nil } @@ -219,7 +221,9 @@ func (b *BeaconState) AppendBalance(bal uint64) error { } b.state.Balances = append(bals, bal) + balIdx := len(b.state.Balances) - 1 b.markFieldAsDirty(balances) + b.addDirtyIndices(balances, []uint64{uint64(balIdx)}) return nil } diff --git a/beacon-chain/state/v2/state_trie.go b/beacon-chain/state/v2/state_trie.go index 40a6cab31c..8015605ab0 100644 --- a/beacon-chain/state/v2/state_trie.go +++ b/beacon-chain/state/v2/state_trie.go @@ -12,6 +12,7 @@ import ( "github.com/prysmaticlabs/prysm/beacon-chain/state/fieldtrie" "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" "github.com/prysmaticlabs/prysm/beacon-chain/state/types" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/config/params" "github.com/prysmaticlabs/prysm/container/slice" "github.com/prysmaticlabs/prysm/crypto/hash" @@ -324,6 +325,20 @@ func (b *BeaconState) rootSelector(ctx context.Context, field types.FieldIndex) } return b.recomputeFieldTrie(validators, b.state.Validators) case balances: + if features.Get().EnableBalanceTrieComputation { + if b.rebuildTrie[field] { + maxBalCap := params.BeaconConfig().ValidatorRegistryLimit + elemSize := uint64(8) + balLimit := (maxBalCap*elemSize + 31) / 32 + err := b.resetFieldTrie(field, b.state.Balances, balLimit) + if err != nil { + return [32]byte{}, err + } + delete(b.rebuildTrie, field) + return b.stateFieldLeaves[field].TrieRoot() + } + return b.recomputeFieldTrie(balances, b.state.Balances) + } return stateutil.Uint64ListRootWithRegistryLimit(b.state.Balances) case randaoMixes: if b.rebuildTrie[field] { diff --git a/beacon-chain/state/v2/state_trie_test.go b/beacon-chain/state/v2/state_trie_test.go index 8f8508c70f..e0b5d295aa 100644 --- a/beacon-chain/state/v2/state_trie_test.go +++ b/beacon-chain/state/v2/state_trie_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/config/params" "github.com/prysmaticlabs/prysm/encoding/bytesutil" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" @@ -13,6 +14,12 @@ import ( "github.com/prysmaticlabs/prysm/testing/require" ) +func TestMain(m *testing.M) { + resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true}) + defer resetCfg() + m.Run() +} + func TestValidatorMap_DistinctCopy(t *testing.T) { count := uint64(100) vals := make([]*ethpb.Validator, 0, count) diff --git a/beacon-chain/state/v2/types.go b/beacon-chain/state/v2/types.go index aa70308ba6..193037e7ab 100644 --- a/beacon-chain/state/v2/types.go +++ b/beacon-chain/state/v2/types.go @@ -22,6 +22,9 @@ func init() { // Initialize the composite arrays. fieldMap[types.Eth1DataVotes] = types.CompositeArray fieldMap[types.Validators] = types.CompositeArray + + // Initialize Compressed Arrays + fieldMap[types.Balances] = types.CompressedArray } // fieldMap keeps track of each field diff --git a/config/features/config.go b/config/features/config.go index e86a178261..f531f860de 100644 --- a/config/features/config.go +++ b/config/features/config.go @@ -52,6 +52,7 @@ type Flags struct { EnableHistoricalSpaceRepresentation bool // EnableHistoricalSpaceRepresentation enables the saving of registry validators in separate buckets to save space EnableGetBlockOptimizations bool // EnableGetBlockOptimizations optimizes some elements of the GetBlock() function. EnableBatchVerification bool // EnableBatchVerification enables batch signature verification on gossip messages. + EnableBalanceTrieComputation bool // EnableBalanceTrieComputation enables our beacon state to use balance tries for hash tree root operations. // Logging related toggles. DisableGRPCConnectionLogs bool // Disables logging when a new grpc client has connected. @@ -223,6 +224,10 @@ func ConfigureBeaconChain(ctx *cli.Context) { logEnabled(enableBatchGossipVerification) cfg.EnableBatchVerification = true } + if ctx.Bool(enableBalanceTrieComputation.Name) { + logEnabled(enableBalanceTrieComputation) + cfg.EnableBalanceTrieComputation = true + } Init(cfg) } diff --git a/config/features/flags.go b/config/features/flags.go index 5cd136a303..2400bf7898 100644 --- a/config/features/flags.go +++ b/config/features/flags.go @@ -139,6 +139,10 @@ var ( Name: "enable-batch-gossip-verification", Usage: "This enables batch verification of signatures received over gossip.", } + enableBalanceTrieComputation = &cli.BoolFlag{ + Name: "enable-balance-trie-computation", + Usage: "This enables optimized hash tree root operations for our balance field.", + } ) // devModeFlags holds list of flags that are set when development mode is on. @@ -147,6 +151,7 @@ var devModeFlags = []cli.Flag{ forceOptMaxCoverAggregationStategy, enableGetBlockOptimizations, enableBatchGossipVerification, + enableBalanceTrieComputation, } // ValidatorFlags contains a list of all the feature flags that apply to the validator client. @@ -192,6 +197,7 @@ var BeaconChainFlags = append(deprecatedFlags, []cli.Flag{ disableCorrectlyPruneCanonicalAtts, disableActiveBalanceCache, enableBatchGossipVerification, + enableBalanceTrieComputation, }...) // E2EBeaconChainFlags contains a list of the beacon chain feature flags to be tested in E2E. diff --git a/encoding/ssz/helpers.go b/encoding/ssz/helpers.go index 04a1b69392..7532a7a89e 100644 --- a/encoding/ssz/helpers.go +++ b/encoding/ssz/helpers.go @@ -8,6 +8,7 @@ import ( "github.com/minio/sha256-simd" "github.com/pkg/errors" "github.com/prysmaticlabs/go-bitfield" + "github.com/prysmaticlabs/prysm/encoding/bytesutil" ) const bytesPerChunk = 32 @@ -113,6 +114,53 @@ func Pack(serializedItems [][]byte) ([][]byte, error) { return chunks, nil } +// PackByChunk a given byte array's final chunk with zeroes if needed. +func PackByChunk(serializedItems [][]byte) ([][bytesPerChunk]byte, error) { + emptyChunk := [bytesPerChunk]byte{} + // If there are no items, we return an empty chunk. + if len(serializedItems) == 0 { + return [][bytesPerChunk]byte{emptyChunk}, nil + } else if len(serializedItems[0]) == bytesPerChunk { + // If each item has exactly BYTES_PER_CHUNK length, we return the list of serialized items. + chunks := make([][bytesPerChunk]byte, 0, len(serializedItems)) + for _, c := range serializedItems { + chunks = append(chunks, bytesutil.ToBytes32(c)) + } + return chunks, nil + } + // We flatten the list in order to pack its items into byte chunks correctly. + var orderedItems []byte + for _, item := range serializedItems { + orderedItems = append(orderedItems, item...) + } + // If all our serialized item slices are length zero, we + // exit early. + if len(orderedItems) == 0 { + return [][bytesPerChunk]byte{emptyChunk}, nil + } + numItems := len(orderedItems) + var chunks [][bytesPerChunk]byte + for i := 0; i < numItems; i += bytesPerChunk { + j := i + bytesPerChunk + // We create our upper bound index of the chunk, if it is greater than numItems, + // we set it as numItems itself. + if j > numItems { + j = numItems + } + // We create chunks from the list of items based on the + // indices determined above. + // Right-pad the last chunk with zero bytes if it does not + // have length bytesPerChunk from the helper. + // The ToBytes32 helper allocates a 32-byte array, before + // copying the ordered items in. This ensures that even if + // the last chunk is != 32 in length, we will right-pad it with + // zero bytes. + chunks = append(chunks, bytesutil.ToBytes32(orderedItems[i:j])) + } + + return chunks, nil +} + // MixInLength appends hash length to root func MixInLength(root [32]byte, length []byte) [32]byte { var hash [32]byte diff --git a/encoding/ssz/helpers_test.go b/encoding/ssz/helpers_test.go index bf0d8842b5..ce24f5e639 100644 --- a/encoding/ssz/helpers_test.go +++ b/encoding/ssz/helpers_test.go @@ -94,6 +94,22 @@ func TestPack(t *testing.T) { } } +func TestPackByChunk(t *testing.T) { + byteSlice2D := [][]byte{ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 2, 5, 2, 6, 2, 7}, + {1, 1, 2, 3, 5, 8, 13, 21, 34}, + } + expected := [][32]byte{{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 2, 5, 2, 6, 2, 7, 1, 1}, + {2, 3, 5, 8, 13, 21, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}} + + result, err := ssz.PackByChunk(byteSlice2D) + require.NoError(t, err) + assert.Equal(t, len(expected), len(result)) + for i, v := range expected { + assert.DeepEqual(t, v, result[i]) + } +} + func TestMixInLength(t *testing.T) { byteSlice := [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32} length := []byte{1, 2, 3} diff --git a/testing/spectest/mainnet/altair/epoch_processing/BUILD.bazel b/testing/spectest/mainnet/altair/epoch_processing/BUILD.bazel index 838223c825..87b1b3f6bd 100644 --- a/testing/spectest/mainnet/altair/epoch_processing/BUILD.bazel +++ b/testing/spectest/mainnet/altair/epoch_processing/BUILD.bazel @@ -5,6 +5,7 @@ go_test( size = "small", srcs = [ "effective_balance_updates_test.go", + "epoch_processing_test.go", "eth1_data_reset_test.go", "historical_roots_update_test.go", "inactivity_updates_test.go", @@ -21,5 +22,8 @@ go_test( ], shard_count = 4, tags = ["spectest"], - deps = ["//testing/spectest/shared/altair/epoch_processing:go_default_library"], + deps = [ + "//config/features:go_default_library", + "//testing/spectest/shared/altair/epoch_processing:go_default_library", + ], ) diff --git a/testing/spectest/mainnet/altair/epoch_processing/epoch_processing_test.go b/testing/spectest/mainnet/altair/epoch_processing/epoch_processing_test.go new file mode 100644 index 0000000000..edb002ae95 --- /dev/null +++ b/testing/spectest/mainnet/altair/epoch_processing/epoch_processing_test.go @@ -0,0 +1,13 @@ +package epoch_processing + +import ( + "testing" + + "github.com/prysmaticlabs/prysm/config/features" +) + +func TestMain(m *testing.M) { + resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true}) + defer resetCfg() + m.Run() +} diff --git a/testing/spectest/mainnet/altair/random/BUILD.bazel b/testing/spectest/mainnet/altair/random/BUILD.bazel index 404d7434c5..9adfafd7bb 100644 --- a/testing/spectest/mainnet/altair/random/BUILD.bazel +++ b/testing/spectest/mainnet/altair/random/BUILD.bazel @@ -8,5 +8,8 @@ go_test( "@consensus_spec_tests_mainnet//:test_data", ], tags = ["spectest"], - deps = ["//testing/spectest/shared/altair/sanity:go_default_library"], + deps = [ + "//config/features:go_default_library", + "//testing/spectest/shared/altair/sanity:go_default_library", + ], ) diff --git a/testing/spectest/mainnet/altair/random/random_test.go b/testing/spectest/mainnet/altair/random/random_test.go index 04b41e2f74..5b60987541 100644 --- a/testing/spectest/mainnet/altair/random/random_test.go +++ b/testing/spectest/mainnet/altair/random/random_test.go @@ -3,9 +3,16 @@ package random import ( "testing" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/testing/spectest/shared/altair/sanity" ) +func TestMain(m *testing.M) { + resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true}) + defer resetCfg() + m.Run() +} + func TestMainnet_Altair_Random(t *testing.T) { sanity.RunBlockProcessingTest(t, "mainnet", "random/random/pyspec_tests") } diff --git a/testing/spectest/mainnet/altair/rewards/BUILD.bazel b/testing/spectest/mainnet/altair/rewards/BUILD.bazel index 6abf6e69c5..ee4c80ff1f 100644 --- a/testing/spectest/mainnet/altair/rewards/BUILD.bazel +++ b/testing/spectest/mainnet/altair/rewards/BUILD.bazel @@ -8,5 +8,8 @@ go_test( "@consensus_spec_tests_mainnet//:test_data", ], tags = ["spectest"], - deps = ["//testing/spectest/shared/altair/rewards:go_default_library"], + deps = [ + "//config/features:go_default_library", + "//testing/spectest/shared/altair/rewards:go_default_library", + ], ) diff --git a/testing/spectest/mainnet/altair/rewards/rewards_test.go b/testing/spectest/mainnet/altair/rewards/rewards_test.go index baf2a8b4c5..0e9c3d95e9 100644 --- a/testing/spectest/mainnet/altair/rewards/rewards_test.go +++ b/testing/spectest/mainnet/altair/rewards/rewards_test.go @@ -3,9 +3,16 @@ package rewards import ( "testing" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/testing/spectest/shared/altair/rewards" ) +func TestMain(m *testing.M) { + resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true}) + defer resetCfg() + m.Run() +} + func TestMainnet_Altair_Rewards(t *testing.T) { rewards.RunPrecomputeRewardsAndPenaltiesTests(t, "mainnet") } diff --git a/testing/spectest/mainnet/altair/sanity/BUILD.bazel b/testing/spectest/mainnet/altair/sanity/BUILD.bazel index d5ec7b405c..1b3e7b1280 100644 --- a/testing/spectest/mainnet/altair/sanity/BUILD.bazel +++ b/testing/spectest/mainnet/altair/sanity/BUILD.bazel @@ -5,11 +5,15 @@ go_test( size = "medium", srcs = [ "blocks_test.go", + "sanity_test.go", "slots_test.go", ], data = glob(["*.yaml"]) + [ "@consensus_spec_tests_mainnet//:test_data", ], tags = ["spectest"], - deps = ["//testing/spectest/shared/altair/sanity:go_default_library"], + deps = [ + "//config/features:go_default_library", + "//testing/spectest/shared/altair/sanity:go_default_library", + ], ) diff --git a/testing/spectest/mainnet/altair/sanity/sanity_test.go b/testing/spectest/mainnet/altair/sanity/sanity_test.go new file mode 100644 index 0000000000..469b52d3fc --- /dev/null +++ b/testing/spectest/mainnet/altair/sanity/sanity_test.go @@ -0,0 +1,13 @@ +package sanity + +import ( + "testing" + + "github.com/prysmaticlabs/prysm/config/features" +) + +func TestMain(m *testing.M) { + resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true}) + defer resetCfg() + m.Run() +} diff --git a/testing/spectest/mainnet/phase0/epoch_processing/BUILD.bazel b/testing/spectest/mainnet/phase0/epoch_processing/BUILD.bazel index 81dad5831b..fd31993271 100644 --- a/testing/spectest/mainnet/phase0/epoch_processing/BUILD.bazel +++ b/testing/spectest/mainnet/phase0/epoch_processing/BUILD.bazel @@ -22,6 +22,7 @@ go_test( shard_count = 4, tags = ["spectest"], deps = [ + "//config/features:go_default_library", "//config/params:go_default_library", "//testing/spectest/shared/phase0/epoch_processing:go_default_library", ], diff --git a/testing/spectest/mainnet/phase0/epoch_processing/epoch_processing_test.go b/testing/spectest/mainnet/phase0/epoch_processing/epoch_processing_test.go index b610764a1c..43a4d03a86 100644 --- a/testing/spectest/mainnet/phase0/epoch_processing/epoch_processing_test.go +++ b/testing/spectest/mainnet/phase0/epoch_processing/epoch_processing_test.go @@ -3,6 +3,7 @@ package epoch_processing import ( "testing" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/config/params" ) @@ -12,6 +13,8 @@ func TestMain(m *testing.M) { c := params.BeaconConfig() c.MinGenesisActiveValidatorCount = 16384 params.OverrideBeaconConfig(c) + resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true}) + defer resetCfg() m.Run() } diff --git a/testing/spectest/mainnet/phase0/random/BUILD.bazel b/testing/spectest/mainnet/phase0/random/BUILD.bazel index bc450e49b3..07d61cc785 100644 --- a/testing/spectest/mainnet/phase0/random/BUILD.bazel +++ b/testing/spectest/mainnet/phase0/random/BUILD.bazel @@ -8,5 +8,8 @@ go_test( "@consensus_spec_tests_mainnet//:test_data", ], tags = ["spectest"], - deps = ["//testing/spectest/shared/phase0/sanity:go_default_library"], + deps = [ + "//config/features:go_default_library", + "//testing/spectest/shared/phase0/sanity:go_default_library", + ], ) diff --git a/testing/spectest/mainnet/phase0/random/random_test.go b/testing/spectest/mainnet/phase0/random/random_test.go index 609a7e18ed..486aa3cbe8 100644 --- a/testing/spectest/mainnet/phase0/random/random_test.go +++ b/testing/spectest/mainnet/phase0/random/random_test.go @@ -3,9 +3,16 @@ package random import ( "testing" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/testing/spectest/shared/phase0/sanity" ) +func TestMain(m *testing.M) { + resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true}) + defer resetCfg() + m.Run() +} + func TestMainnet_Phase0_Random(t *testing.T) { sanity.RunBlockProcessingTest(t, "mainnet", "random/random/pyspec_tests") } diff --git a/testing/spectest/mainnet/phase0/rewards/BUILD.bazel b/testing/spectest/mainnet/phase0/rewards/BUILD.bazel index 91761b53a4..ac03295973 100644 --- a/testing/spectest/mainnet/phase0/rewards/BUILD.bazel +++ b/testing/spectest/mainnet/phase0/rewards/BUILD.bazel @@ -8,5 +8,8 @@ go_test( "@consensus_spec_tests_mainnet//:test_data", ], tags = ["spectest"], - deps = ["//testing/spectest/shared/phase0/rewards:go_default_library"], + deps = [ + "//config/features:go_default_library", + "//testing/spectest/shared/phase0/rewards:go_default_library", + ], ) diff --git a/testing/spectest/mainnet/phase0/rewards/rewards_test.go b/testing/spectest/mainnet/phase0/rewards/rewards_test.go index 0e1b4a4f69..d448bf829a 100644 --- a/testing/spectest/mainnet/phase0/rewards/rewards_test.go +++ b/testing/spectest/mainnet/phase0/rewards/rewards_test.go @@ -3,9 +3,16 @@ package rewards import ( "testing" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/testing/spectest/shared/phase0/rewards" ) +func TestMain(m *testing.M) { + resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true}) + defer resetCfg() + m.Run() +} + func TestMainnet_Phase0_Rewards(t *testing.T) { rewards.RunPrecomputeRewardsAndPenaltiesTests(t, "mainnet") } diff --git a/testing/spectest/mainnet/phase0/sanity/BUILD.bazel b/testing/spectest/mainnet/phase0/sanity/BUILD.bazel index 8d43eb9041..f08fc64ccb 100644 --- a/testing/spectest/mainnet/phase0/sanity/BUILD.bazel +++ b/testing/spectest/mainnet/phase0/sanity/BUILD.bazel @@ -5,11 +5,15 @@ go_test( size = "medium", srcs = [ "blocks_test.go", + "sanity_test.go", "slots_test.go", ], data = glob(["*.yaml"]) + [ "@consensus_spec_tests_mainnet//:test_data", ], tags = ["spectest"], - deps = ["//testing/spectest/shared/phase0/sanity:go_default_library"], + deps = [ + "//config/features:go_default_library", + "//testing/spectest/shared/phase0/sanity:go_default_library", + ], ) diff --git a/testing/spectest/mainnet/phase0/sanity/sanity_test.go b/testing/spectest/mainnet/phase0/sanity/sanity_test.go new file mode 100644 index 0000000000..469b52d3fc --- /dev/null +++ b/testing/spectest/mainnet/phase0/sanity/sanity_test.go @@ -0,0 +1,13 @@ +package sanity + +import ( + "testing" + + "github.com/prysmaticlabs/prysm/config/features" +) + +func TestMain(m *testing.M) { + resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true}) + defer resetCfg() + m.Run() +}