mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-09 13:28:01 -05:00
Add Balance Field Trie (#9793)
* save stuff * fix in v1 * clean up more * fix bugs * add comments and clean up * add flag + test * add tests * fmt * radek's review * gaz * kasey's review * gaz and new conditional * improve naming
This commit is contained in:
@@ -12,6 +12,8 @@ go_library(
|
||||
"//beacon-chain/state/stateutil:go_default_library",
|
||||
"//beacon-chain/state/types:go_default_library",
|
||||
"//crypto/hash:go_default_library",
|
||||
"//encoding/bytesutil:go_default_library",
|
||||
"//encoding/ssz:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"//runtime/version:go_default_library",
|
||||
"@com_github_pkg_errors//:go_default_library",
|
||||
@@ -26,6 +28,7 @@ go_test(
|
||||
],
|
||||
embed = [":go_default_library"],
|
||||
deps = [
|
||||
"//beacon-chain/state/stateutil:go_default_library",
|
||||
"//beacon-chain/state/types:go_default_library",
|
||||
"//beacon-chain/state/v1:go_default_library",
|
||||
"//config/params:go_default_library",
|
||||
|
||||
@@ -18,6 +18,7 @@ type FieldTrie struct {
|
||||
field types.FieldIndex
|
||||
dataType types.DataType
|
||||
length uint64
|
||||
numOfElems int
|
||||
}
|
||||
|
||||
// NewFieldTrie is the constructor for the field trie data structure. It creates the corresponding
|
||||
@@ -26,18 +27,19 @@ type FieldTrie struct {
|
||||
func NewFieldTrie(field types.FieldIndex, dataType types.DataType, elements interface{}, length uint64) (*FieldTrie, error) {
|
||||
if elements == nil {
|
||||
return &FieldTrie{
|
||||
field: field,
|
||||
dataType: dataType,
|
||||
reference: stateutil.NewRef(1),
|
||||
RWMutex: new(sync.RWMutex),
|
||||
length: length,
|
||||
field: field,
|
||||
dataType: dataType,
|
||||
reference: stateutil.NewRef(1),
|
||||
RWMutex: new(sync.RWMutex),
|
||||
length: length,
|
||||
numOfElems: 0,
|
||||
}, nil
|
||||
}
|
||||
fieldRoots, err := fieldConverters(field, []uint64{}, elements, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateElements(field, elements, length); err != nil {
|
||||
if err := validateElements(field, dataType, elements, length); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch dataType {
|
||||
@@ -53,8 +55,9 @@ func NewFieldTrie(field types.FieldIndex, dataType types.DataType, elements inte
|
||||
reference: stateutil.NewRef(1),
|
||||
RWMutex: new(sync.RWMutex),
|
||||
length: length,
|
||||
numOfElems: reflect.ValueOf(elements).Len(),
|
||||
}, nil
|
||||
case types.CompositeArray:
|
||||
case types.CompositeArray, types.CompressedArray:
|
||||
return &FieldTrie{
|
||||
fieldLayers: stateutil.ReturnTrieLayerVariable(fieldRoots, length),
|
||||
field: field,
|
||||
@@ -62,6 +65,7 @@ func NewFieldTrie(field types.FieldIndex, dataType types.DataType, elements inte
|
||||
reference: stateutil.NewRef(1),
|
||||
RWMutex: new(sync.RWMutex),
|
||||
length: length,
|
||||
numOfElems: reflect.ValueOf(elements).Len(),
|
||||
}, nil
|
||||
default:
|
||||
return nil, errors.Errorf("unrecognized data type in field map: %v", reflect.TypeOf(dataType).Name())
|
||||
@@ -92,13 +96,40 @@ func (f *FieldTrie) RecomputeTrie(indices []uint64, elements interface{}) ([32]b
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
f.numOfElems = reflect.ValueOf(elements).Len()
|
||||
return fieldRoot, nil
|
||||
case types.CompositeArray:
|
||||
fieldRoot, f.fieldLayers, err = stateutil.RecomputeFromLayerVariable(fieldRoots, indices, f.fieldLayers)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
f.numOfElems = reflect.ValueOf(elements).Len()
|
||||
return stateutil.AddInMixin(fieldRoot, uint64(len(f.fieldLayers[0])))
|
||||
case types.CompressedArray:
|
||||
numOfElems, err := f.field.ElemsInChunk()
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
// We remove the duplicates here in order to prevent
|
||||
// duplicated insertions into the trie.
|
||||
newIndices := []uint64{}
|
||||
indexExists := make(map[uint64]bool)
|
||||
newRoots := make([][32]byte, 0, len(fieldRoots)/int(numOfElems))
|
||||
for i, idx := range indices {
|
||||
startIdx := idx / numOfElems
|
||||
if indexExists[startIdx] {
|
||||
continue
|
||||
}
|
||||
newIndices = append(newIndices, startIdx)
|
||||
indexExists[startIdx] = true
|
||||
newRoots = append(newRoots, fieldRoots[i])
|
||||
}
|
||||
fieldRoot, f.fieldLayers, err = stateutil.RecomputeFromLayerVariable(newRoots, newIndices, f.fieldLayers)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
f.numOfElems = reflect.ValueOf(elements).Len()
|
||||
return stateutil.AddInMixin(fieldRoot, uint64(f.numOfElems))
|
||||
default:
|
||||
return [32]byte{}, errors.Errorf("unrecognized data type in field map: %v", reflect.TypeOf(f.dataType).Name())
|
||||
}
|
||||
@@ -109,11 +140,12 @@ func (f *FieldTrie) RecomputeTrie(indices []uint64, elements interface{}) ([32]b
|
||||
func (f *FieldTrie) CopyTrie() *FieldTrie {
|
||||
if f.fieldLayers == nil {
|
||||
return &FieldTrie{
|
||||
field: f.field,
|
||||
dataType: f.dataType,
|
||||
reference: stateutil.NewRef(1),
|
||||
RWMutex: new(sync.RWMutex),
|
||||
length: f.length,
|
||||
field: f.field,
|
||||
dataType: f.dataType,
|
||||
reference: stateutil.NewRef(1),
|
||||
RWMutex: new(sync.RWMutex),
|
||||
length: f.length,
|
||||
numOfElems: f.numOfElems,
|
||||
}
|
||||
}
|
||||
dstFieldTrie := make([][]*[32]byte, len(f.fieldLayers))
|
||||
@@ -128,6 +160,7 @@ func (f *FieldTrie) CopyTrie() *FieldTrie {
|
||||
reference: stateutil.NewRef(1),
|
||||
RWMutex: new(sync.RWMutex),
|
||||
length: f.length,
|
||||
numOfElems: f.numOfElems,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -139,6 +172,9 @@ func (f *FieldTrie) TrieRoot() ([32]byte, error) {
|
||||
case types.CompositeArray:
|
||||
trieRoot := *f.fieldLayers[len(f.fieldLayers)-1][0]
|
||||
return stateutil.AddInMixin(trieRoot, uint64(len(f.fieldLayers[0])))
|
||||
case types.CompressedArray:
|
||||
trieRoot := *f.fieldLayers[len(f.fieldLayers)-1][0]
|
||||
return stateutil.AddInMixin(trieRoot, uint64(f.numOfElems))
|
||||
default:
|
||||
return [32]byte{}, errors.Errorf("unrecognized data type in field map: %v", reflect.TypeOf(f.dataType).Name())
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package fieldtrie
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
@@ -8,20 +9,37 @@ import (
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/types"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/runtime/version"
|
||||
)
|
||||
|
||||
func (f *FieldTrie) validateIndices(idxs []uint64) error {
|
||||
length := f.length
|
||||
if f.dataType == types.CompressedArray {
|
||||
comLength, err := f.field.ElemsInChunk()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
length *= comLength
|
||||
}
|
||||
for _, idx := range idxs {
|
||||
if idx >= f.length {
|
||||
return errors.Errorf("invalid index for field %s: %d >= length %d", f.field.String(version.Phase0), idx, f.length)
|
||||
if idx >= length {
|
||||
return errors.Errorf("invalid index for field %s: %d >= length %d", f.field.String(version.Phase0), idx, length)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateElements(field types.FieldIndex, elements interface{}, length uint64) error {
|
||||
func validateElements(field types.FieldIndex, dataType types.DataType, elements interface{}, length uint64) error {
|
||||
if dataType == types.CompressedArray {
|
||||
comLength, err := field.ElemsInChunk()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
length *= comLength
|
||||
}
|
||||
val := reflect.ValueOf(elements)
|
||||
if val.Len() > int(length) {
|
||||
return errors.Errorf("elements length is larger than expected for field %s: %d > %d", field.String(version.Phase0), val.Len(), length)
|
||||
@@ -38,21 +56,21 @@ func fieldConverters(field types.FieldIndex, indices []uint64, elements interfac
|
||||
return nil, errors.Errorf("Wanted type of %v but got %v",
|
||||
reflect.TypeOf([][]byte{}).Name(), reflect.TypeOf(elements).Name())
|
||||
}
|
||||
return stateutil.HandleByteArrays(val, indices, convertAll)
|
||||
return handleByteArrays(val, indices, convertAll)
|
||||
case types.Eth1DataVotes:
|
||||
val, ok := elements.([]*ethpb.Eth1Data)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("Wanted type of %v but got %v",
|
||||
reflect.TypeOf([]*ethpb.Eth1Data{}).Name(), reflect.TypeOf(elements).Name())
|
||||
}
|
||||
return HandleEth1DataSlice(val, indices, convertAll)
|
||||
return handleEth1DataSlice(val, indices, convertAll)
|
||||
case types.Validators:
|
||||
val, ok := elements.([]*ethpb.Validator)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("Wanted type of %v but got %v",
|
||||
reflect.TypeOf([]*ethpb.Validator{}).Name(), reflect.TypeOf(elements).Name())
|
||||
}
|
||||
return stateutil.HandleValidatorSlice(val, indices, convertAll)
|
||||
return handleValidatorSlice(val, indices, convertAll)
|
||||
case types.PreviousEpochAttestations, types.CurrentEpochAttestations:
|
||||
val, ok := elements.([]*ethpb.PendingAttestation)
|
||||
if !ok {
|
||||
@@ -60,13 +78,87 @@ func fieldConverters(field types.FieldIndex, indices []uint64, elements interfac
|
||||
reflect.TypeOf([]*ethpb.PendingAttestation{}).Name(), reflect.TypeOf(elements).Name())
|
||||
}
|
||||
return handlePendingAttestation(val, indices, convertAll)
|
||||
case types.Balances:
|
||||
val, ok := elements.([]uint64)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("Wanted type of %v but got %v",
|
||||
reflect.TypeOf([]uint64{}).Name(), reflect.TypeOf(elements).Name())
|
||||
}
|
||||
return handleBalanceSlice(val, indices, convertAll)
|
||||
default:
|
||||
return [][32]byte{}, errors.Errorf("got unsupported type of %v", reflect.TypeOf(elements).Name())
|
||||
}
|
||||
}
|
||||
|
||||
// HandleEth1DataSlice processes a list of eth1data and indices into the appropriate roots.
|
||||
func HandleEth1DataSlice(val []*ethpb.Eth1Data, indices []uint64, convertAll bool) ([][32]byte, error) {
|
||||
// handleByteArrays computes and returns byte arrays in a slice of root format.
|
||||
func handleByteArrays(val [][]byte, indices []uint64, convertAll bool) ([][32]byte, error) {
|
||||
length := len(indices)
|
||||
if convertAll {
|
||||
length = len(val)
|
||||
}
|
||||
roots := make([][32]byte, 0, length)
|
||||
rootCreator := func(input []byte) {
|
||||
newRoot := bytesutil.ToBytes32(input)
|
||||
roots = append(roots, newRoot)
|
||||
}
|
||||
if convertAll {
|
||||
for i := range val {
|
||||
rootCreator(val[i])
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
if len(val) > 0 {
|
||||
for _, idx := range indices {
|
||||
if idx > uint64(len(val))-1 {
|
||||
return nil, fmt.Errorf("index %d greater than number of byte arrays %d", idx, len(val))
|
||||
}
|
||||
rootCreator(val[idx])
|
||||
}
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
|
||||
// handleValidatorSlice returns the validator indices in a slice of root format.
|
||||
func handleValidatorSlice(val []*ethpb.Validator, indices []uint64, convertAll bool) ([][32]byte, error) {
|
||||
length := len(indices)
|
||||
if convertAll {
|
||||
length = len(val)
|
||||
}
|
||||
roots := make([][32]byte, 0, length)
|
||||
hasher := hash.CustomSHA256Hasher()
|
||||
rootCreator := func(input *ethpb.Validator) error {
|
||||
newRoot, err := stateutil.ValidatorRootWithHasher(hasher, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roots = append(roots, newRoot)
|
||||
return nil
|
||||
}
|
||||
if convertAll {
|
||||
for i := range val {
|
||||
err := rootCreator(val[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
if len(val) > 0 {
|
||||
for _, idx := range indices {
|
||||
if idx > uint64(len(val))-1 {
|
||||
return nil, fmt.Errorf("index %d greater than number of validators %d", idx, len(val))
|
||||
}
|
||||
err := rootCreator(val[idx])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
|
||||
// handleEth1DataSlice processes a list of eth1data and indices into the appropriate roots.
|
||||
func handleEth1DataSlice(val []*ethpb.Eth1Data, indices []uint64, convertAll bool) ([][32]byte, error) {
|
||||
length := len(indices)
|
||||
if convertAll {
|
||||
length = len(val)
|
||||
@@ -141,3 +233,48 @@ func handlePendingAttestation(val []*ethpb.PendingAttestation, indices []uint64,
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
|
||||
func handleBalanceSlice(val []uint64, indices []uint64, convertAll bool) ([][32]byte, error) {
|
||||
if convertAll {
|
||||
balancesMarshaling := make([][]byte, 0)
|
||||
for _, b := range val {
|
||||
balanceBuf := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(balanceBuf, b)
|
||||
balancesMarshaling = append(balancesMarshaling, balanceBuf)
|
||||
}
|
||||
balancesChunks, err := ssz.PackByChunk(balancesMarshaling)
|
||||
if err != nil {
|
||||
return [][32]byte{}, errors.Wrap(err, "could not pack balances into chunks")
|
||||
}
|
||||
return balancesChunks, nil
|
||||
}
|
||||
if len(val) > 0 {
|
||||
numOfElems, err := types.Balances.ElemsInChunk()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roots := [][32]byte{}
|
||||
for _, idx := range indices {
|
||||
// We split the indexes into their relevant groups. Balances
|
||||
// are compressed according to 4 values -> 1 chunk.
|
||||
startIdx := idx / numOfElems
|
||||
startGroup := startIdx * numOfElems
|
||||
chunk := [32]byte{}
|
||||
sizeOfElem := len(chunk) / int(numOfElems)
|
||||
for i, j := 0, startGroup; j < startGroup+numOfElems; i, j = i+sizeOfElem, j+1 {
|
||||
wantedVal := uint64(0)
|
||||
// We are adding chunks in sets of 4, if the set is at the edge of the array
|
||||
// then you will need to zero out the rest of the chunk. Ex : 41 indexes,
|
||||
// so 41 % 4 = 1 . There are 3 indexes, which do not exist yet but we
|
||||
// have to add in as a root. These 3 indexes are then given a 'zero' value.
|
||||
if int(j) < len(val) {
|
||||
wantedVal = val[j]
|
||||
}
|
||||
binary.LittleEndian.PutUint64(chunk[i:i+sizeOfElem], wantedVal)
|
||||
}
|
||||
roots = append(roots, chunk)
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
return [][32]byte{}, nil
|
||||
}
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
package fieldtrie
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/types"
|
||||
"github.com/prysmaticlabs/prysm/config/params"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/testing/assert"
|
||||
)
|
||||
@@ -17,7 +22,64 @@ func Test_handlePendingAttestation_OutOfRange(t *testing.T) {
|
||||
func Test_handleEth1DataSlice_OutOfRange(t *testing.T) {
|
||||
items := make([]*ethpb.Eth1Data, 1)
|
||||
indices := []uint64{3}
|
||||
_, err := HandleEth1DataSlice(items, indices, false)
|
||||
_, err := handleEth1DataSlice(items, indices, false)
|
||||
assert.ErrorContains(t, "index 3 greater than number of items in eth1 data slice 1", err)
|
||||
|
||||
}
|
||||
|
||||
func Test_handleValidatorSlice_OutOfRange(t *testing.T) {
|
||||
vals := make([]*ethpb.Validator, 1)
|
||||
indices := []uint64{3}
|
||||
_, err := handleValidatorSlice(vals, indices, false)
|
||||
assert.ErrorContains(t, "index 3 greater than number of validators 1", err)
|
||||
}
|
||||
|
||||
func TestBalancesSlice_CorrectRoots_All(t *testing.T) {
|
||||
balances := []uint64{5, 2929, 34, 1291, 354305}
|
||||
roots, err := handleBalanceSlice(balances, []uint64{}, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
root1 := [32]byte{}
|
||||
binary.LittleEndian.PutUint64(root1[:8], balances[0])
|
||||
binary.LittleEndian.PutUint64(root1[8:16], balances[1])
|
||||
binary.LittleEndian.PutUint64(root1[16:24], balances[2])
|
||||
binary.LittleEndian.PutUint64(root1[24:32], balances[3])
|
||||
|
||||
root2 := [32]byte{}
|
||||
binary.LittleEndian.PutUint64(root2[:8], balances[4])
|
||||
|
||||
assert.DeepEqual(t, roots, [][32]byte{root1, root2})
|
||||
}
|
||||
|
||||
func TestBalancesSlice_CorrectRoots_Some(t *testing.T) {
|
||||
balances := []uint64{5, 2929, 34, 1291, 354305}
|
||||
roots, err := handleBalanceSlice(balances, []uint64{2, 3}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
root1 := [32]byte{}
|
||||
binary.LittleEndian.PutUint64(root1[:8], balances[0])
|
||||
binary.LittleEndian.PutUint64(root1[8:16], balances[1])
|
||||
binary.LittleEndian.PutUint64(root1[16:24], balances[2])
|
||||
binary.LittleEndian.PutUint64(root1[24:32], balances[3])
|
||||
|
||||
// Returns root for each indice(even if duplicated)
|
||||
assert.DeepEqual(t, roots, [][32]byte{root1, root1})
|
||||
}
|
||||
|
||||
func TestValidateIndices_CompressedField(t *testing.T) {
|
||||
fakeTrie := &FieldTrie{
|
||||
RWMutex: new(sync.RWMutex),
|
||||
reference: stateutil.NewRef(0),
|
||||
fieldLayers: nil,
|
||||
field: types.Balances,
|
||||
dataType: types.CompressedArray,
|
||||
length: params.BeaconConfig().ValidatorRegistryLimit / 4,
|
||||
numOfElems: 0,
|
||||
}
|
||||
goodIdx := params.BeaconConfig().ValidatorRegistryLimit - 1
|
||||
assert.NoError(t, fakeTrie.validateIndices([]uint64{goodIdx}))
|
||||
|
||||
badIdx := goodIdx + 1
|
||||
assert.ErrorContains(t, "invalid index for field balances", fakeTrie.validateIndices([]uint64{badIdx}))
|
||||
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ load("@prysm//tools/go:def.bzl", "go_library", "go_test")
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = [
|
||||
"array_root.go",
|
||||
"block_header_root.go",
|
||||
"eth1_root.go",
|
||||
"participation_bit_root.go",
|
||||
@@ -49,7 +48,6 @@ go_test(
|
||||
"state_root_test.go",
|
||||
"stateutil_test.go",
|
||||
"trie_helpers_test.go",
|
||||
"validator_root_test.go",
|
||||
],
|
||||
embed = [":go_default_library"],
|
||||
deps = [
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
package stateutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
)
|
||||
|
||||
// HandleByteArrays computes and returns byte arrays in a slice of root format.
|
||||
func HandleByteArrays(val [][]byte, indices []uint64, convertAll bool) ([][32]byte, error) {
|
||||
length := len(indices)
|
||||
if convertAll {
|
||||
length = len(val)
|
||||
}
|
||||
roots := make([][32]byte, 0, length)
|
||||
rootCreator := func(input []byte) {
|
||||
newRoot := bytesutil.ToBytes32(input)
|
||||
roots = append(roots, newRoot)
|
||||
}
|
||||
if convertAll {
|
||||
for i := range val {
|
||||
rootCreator(val[i])
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
if len(val) > 0 {
|
||||
for _, idx := range indices {
|
||||
if idx > uint64(len(val))-1 {
|
||||
return nil, fmt.Errorf("index %d greater than number of byte arrays %d", idx, len(val))
|
||||
}
|
||||
rootCreator(val[idx])
|
||||
}
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package stateutil
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prysmaticlabs/prysm/config/params"
|
||||
@@ -127,42 +126,3 @@ func ValidatorEncKey(validator *ethpb.Validator) []byte {
|
||||
|
||||
return enc
|
||||
}
|
||||
|
||||
// HandleValidatorSlice returns the validator indices in a slice of root format.
|
||||
func HandleValidatorSlice(val []*ethpb.Validator, indices []uint64, convertAll bool) ([][32]byte, error) {
|
||||
length := len(indices)
|
||||
if convertAll {
|
||||
length = len(val)
|
||||
}
|
||||
roots := make([][32]byte, 0, length)
|
||||
hasher := hash.CustomSHA256Hasher()
|
||||
rootCreator := func(input *ethpb.Validator) error {
|
||||
newRoot, err := ValidatorRootWithHasher(hasher, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
roots = append(roots, newRoot)
|
||||
return nil
|
||||
}
|
||||
if convertAll {
|
||||
for i := range val {
|
||||
err := rootCreator(val[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
if len(val) > 0 {
|
||||
for _, idx := range indices {
|
||||
if idx > uint64(len(val))-1 {
|
||||
return nil, fmt.Errorf("index %d greater than number of validators %d", idx, len(val))
|
||||
}
|
||||
err := rootCreator(val[idx])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
package stateutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/testing/assert"
|
||||
)
|
||||
|
||||
func Test_handleValidatorSlice_OutOfRange(t *testing.T) {
|
||||
vals := make([]*ethpb.Validator, 1)
|
||||
indices := []uint64{3}
|
||||
_, err := HandleValidatorSlice(vals, indices, false)
|
||||
assert.ErrorContains(t, "index 3 greater than number of validators 1", err)
|
||||
}
|
||||
@@ -5,7 +5,10 @@ go_library(
|
||||
srcs = ["types.go"],
|
||||
importpath = "github.com/prysmaticlabs/prysm/beacon-chain/state/types",
|
||||
visibility = ["//beacon-chain:__subpackages__"],
|
||||
deps = ["//runtime/version:go_default_library"],
|
||||
deps = [
|
||||
"//runtime/version:go_default_library",
|
||||
"@com_github_pkg_errors//:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prysmaticlabs/prysm/runtime/version"
|
||||
)
|
||||
|
||||
@@ -18,6 +19,10 @@ const (
|
||||
// CompositeArray represents a variable length array with
|
||||
// a non primitive type.
|
||||
CompositeArray
|
||||
// CompressedArray represents a variable length array which
|
||||
// can pack multiple elements into a leaf of the underlying
|
||||
// trie.
|
||||
CompressedArray
|
||||
)
|
||||
|
||||
// String returns the name of the field index.
|
||||
@@ -84,6 +89,17 @@ func (f FieldIndex) String(stateVersion int) string {
|
||||
}
|
||||
}
|
||||
|
||||
// ElemsInChunk returns the number of elements in the chunk (number of
|
||||
// elements that are able to be packed).
|
||||
func (f FieldIndex) ElemsInChunk() (uint64, error) {
|
||||
switch f {
|
||||
case Balances:
|
||||
return 4, nil
|
||||
default:
|
||||
return 0, errors.Errorf("field %d doesn't support element compression", f)
|
||||
}
|
||||
}
|
||||
|
||||
// Below we define a set of useful enum values for the field
|
||||
// indices of the beacon state. For example, genesisTime is the
|
||||
// 0th field of the beacon state. This is helpful when we are
|
||||
|
||||
@@ -87,6 +87,7 @@ go_test(
|
||||
"//beacon-chain/state:go_default_library",
|
||||
"//beacon-chain/state/stateutil:go_default_library",
|
||||
"//beacon-chain/state/types:go_default_library",
|
||||
"//config/features:go_default_library",
|
||||
"//config/params:go_default_library",
|
||||
"//encoding/bytesutil:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
types "github.com/prysmaticlabs/eth2-types"
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
|
||||
stateTypes "github.com/prysmaticlabs/prysm/beacon-chain/state/types"
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"google.golang.org/protobuf/proto"
|
||||
@@ -172,6 +173,10 @@ func (b *BeaconState) addDirtyIndices(index stateTypes.FieldIndex, indices []uin
|
||||
if b.rebuildTrie[index] {
|
||||
return
|
||||
}
|
||||
// Exit early if balance trie computation isn't enabled.
|
||||
if !features.Get().EnableBalanceTrieComputation && index == balances {
|
||||
return
|
||||
}
|
||||
totalIndicesLen := len(b.dirtyIndices[index]) + len(indices)
|
||||
if totalIndicesLen > indicesLimit {
|
||||
b.rebuildTrie[index] = true
|
||||
|
||||
@@ -103,6 +103,7 @@ func (b *BeaconState) SetBalances(val []uint64) error {
|
||||
|
||||
b.state.Balances = val
|
||||
b.markFieldAsDirty(balances)
|
||||
b.rebuildTrie[balances] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -128,6 +129,7 @@ func (b *BeaconState) UpdateBalancesAtIndex(idx types.ValidatorIndex, val uint64
|
||||
bals[idx] = val
|
||||
b.state.Balances = bals
|
||||
b.markFieldAsDirty(balances)
|
||||
b.addDirtyIndices(balances, []uint64{uint64(idx)})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -219,6 +221,8 @@ func (b *BeaconState) AppendBalance(bal uint64) error {
|
||||
}
|
||||
|
||||
b.state.Balances = append(bals, bal)
|
||||
balIdx := len(b.state.Balances) - 1
|
||||
b.markFieldAsDirty(balances)
|
||||
b.addDirtyIndices(balances, []uint64{uint64(balIdx)})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
types "github.com/prysmaticlabs/eth2-types"
|
||||
"github.com/prysmaticlabs/go-bitfield"
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
|
||||
"github.com/prysmaticlabs/prysm/config/params"
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
@@ -104,3 +107,90 @@ func TestStateTrie_IsNil(t *testing.T) {
|
||||
nonNilState := &BeaconState{state: ðpb.BeaconState{}}
|
||||
assert.Equal(t, false, nonNilState.IsNil())
|
||||
}
|
||||
|
||||
func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
|
||||
count := uint64(100)
|
||||
vals := make([]*ethpb.Validator, 0, count)
|
||||
bals := make([]uint64, 0, count)
|
||||
for i := uint64(1); i < count; i++ {
|
||||
someRoot := [32]byte{}
|
||||
someKey := [48]byte{}
|
||||
copy(someRoot[:], strconv.Itoa(int(i)))
|
||||
copy(someKey[:], strconv.Itoa(int(i)))
|
||||
vals = append(vals, ðpb.Validator{
|
||||
PublicKey: someKey[:],
|
||||
WithdrawalCredentials: someRoot[:],
|
||||
EffectiveBalance: params.BeaconConfig().MaxEffectiveBalance,
|
||||
Slashed: false,
|
||||
ActivationEligibilityEpoch: 1,
|
||||
ActivationEpoch: 1,
|
||||
ExitEpoch: 1,
|
||||
WithdrawableEpoch: 1,
|
||||
})
|
||||
bals = append(bals, params.BeaconConfig().MaxEffectiveBalance)
|
||||
}
|
||||
zeroHash := params.BeaconConfig().ZeroHash
|
||||
mockblockRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot)
|
||||
for i := 0; i < len(mockblockRoots); i++ {
|
||||
mockblockRoots[i] = zeroHash[:]
|
||||
}
|
||||
|
||||
mockstateRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot)
|
||||
for i := 0; i < len(mockstateRoots); i++ {
|
||||
mockstateRoots[i] = zeroHash[:]
|
||||
}
|
||||
mockrandaoMixes := make([][]byte, params.BeaconConfig().EpochsPerHistoricalVector)
|
||||
for i := 0; i < len(mockrandaoMixes); i++ {
|
||||
mockrandaoMixes[i] = zeroHash[:]
|
||||
}
|
||||
var pubKeys [][]byte
|
||||
for i := uint64(0); i < params.BeaconConfig().SyncCommitteeSize; i++ {
|
||||
pubKeys = append(pubKeys, bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength))
|
||||
}
|
||||
st, err := InitializeFromProto(ðpb.BeaconState{
|
||||
Slot: 1,
|
||||
GenesisValidatorsRoot: make([]byte, 32),
|
||||
Fork: ðpb.Fork{
|
||||
PreviousVersion: make([]byte, 4),
|
||||
CurrentVersion: make([]byte, 4),
|
||||
Epoch: 0,
|
||||
},
|
||||
LatestBlockHeader: ðpb.BeaconBlockHeader{
|
||||
ParentRoot: make([]byte, 32),
|
||||
StateRoot: make([]byte, 32),
|
||||
BodyRoot: make([]byte, 32),
|
||||
},
|
||||
Validators: vals,
|
||||
Balances: bals,
|
||||
Eth1Data: ðpb.Eth1Data{
|
||||
DepositRoot: make([]byte, 32),
|
||||
BlockHash: make([]byte, 32),
|
||||
},
|
||||
BlockRoots: mockblockRoots,
|
||||
StateRoots: mockstateRoots,
|
||||
RandaoMixes: mockrandaoMixes,
|
||||
JustificationBits: bitfield.NewBitvector4(),
|
||||
PreviousJustifiedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)},
|
||||
CurrentJustifiedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)},
|
||||
FinalizedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)},
|
||||
Slashings: make([]uint64, params.BeaconConfig().EpochsPerSlashingsVector),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
_, err = st.HashTreeRoot(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
if i%2 == 0 {
|
||||
assert.NoError(t, st.UpdateBalancesAtIndex(types.ValidatorIndex(i), 1000))
|
||||
}
|
||||
if i%3 == 0 {
|
||||
assert.NoError(t, st.AppendBalance(1000))
|
||||
}
|
||||
}
|
||||
_, err = st.HashTreeRoot(context.Background())
|
||||
assert.NoError(t, err)
|
||||
newRt := bytesutil.ToBytes32(st.merkleLayers[0][balances])
|
||||
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.state.Balances)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, wantedRt, newRt, "state roots are unequal")
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/fieldtrie"
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/types"
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/config/params"
|
||||
"github.com/prysmaticlabs/prysm/container/slice"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
@@ -319,6 +320,20 @@ func (b *BeaconState) rootSelector(ctx context.Context, field types.FieldIndex)
|
||||
}
|
||||
return b.recomputeFieldTrie(validators, b.state.Validators)
|
||||
case balances:
|
||||
if features.Get().EnableBalanceTrieComputation {
|
||||
if b.rebuildTrie[field] {
|
||||
maxBalCap := params.BeaconConfig().ValidatorRegistryLimit
|
||||
elemSize := uint64(8)
|
||||
balLimit := (maxBalCap*elemSize + 31) / 32
|
||||
err := b.resetFieldTrie(field, b.state.Balances, balLimit)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
delete(b.rebuildTrie, field)
|
||||
return b.stateFieldLeaves[field].TrieRoot()
|
||||
}
|
||||
return b.recomputeFieldTrie(balances, b.state.Balances)
|
||||
}
|
||||
return stateutil.Uint64ListRootWithRegistryLimit(b.state.Balances)
|
||||
case randaoMixes:
|
||||
if b.rebuildTrie[field] {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state"
|
||||
v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1"
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/config/params"
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
eth "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
@@ -16,6 +17,12 @@ import (
|
||||
"github.com/prysmaticlabs/prysm/testing/util"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
|
||||
defer resetCfg()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestInitializeFromProto(t *testing.T) {
|
||||
testState, _ := util.DeterministicGenesisState(t, 64)
|
||||
pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe())
|
||||
|
||||
@@ -17,7 +17,6 @@ var _ state.BeaconState = (*BeaconState)(nil)
|
||||
|
||||
func init() {
|
||||
fieldMap = make(map[types.FieldIndex]types.DataType, params.BeaconConfig().BeaconStateFieldCount)
|
||||
|
||||
// Initialize the fixed sized arrays.
|
||||
fieldMap[types.BlockRoots] = types.BasicArray
|
||||
fieldMap[types.StateRoots] = types.BasicArray
|
||||
@@ -28,6 +27,7 @@ func init() {
|
||||
fieldMap[types.Validators] = types.CompositeArray
|
||||
fieldMap[types.PreviousEpochAttestations] = types.CompositeArray
|
||||
fieldMap[types.CurrentEpochAttestations] = types.CompositeArray
|
||||
fieldMap[types.Balances] = types.CompressedArray
|
||||
}
|
||||
|
||||
// fieldMap keeps track of each field
|
||||
|
||||
@@ -79,11 +79,13 @@ go_test(
|
||||
"//beacon-chain/state/stateutil:go_default_library",
|
||||
"//beacon-chain/state/types:go_default_library",
|
||||
"//beacon-chain/state/v1:go_default_library",
|
||||
"//config/features:go_default_library",
|
||||
"//config/params:go_default_library",
|
||||
"//encoding/bytesutil:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"//testing/assert:go_default_library",
|
||||
"//testing/require:go_default_library",
|
||||
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
|
||||
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
types "github.com/prysmaticlabs/eth2-types"
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
|
||||
stateTypes "github.com/prysmaticlabs/prysm/beacon-chain/state/types"
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"google.golang.org/protobuf/proto"
|
||||
@@ -171,6 +172,10 @@ func (b *BeaconState) addDirtyIndices(index stateTypes.FieldIndex, indices []uin
|
||||
if b.rebuildTrie[index] {
|
||||
return
|
||||
}
|
||||
// Exit early if balance trie computation isn't enabled.
|
||||
if !features.Get().EnableBalanceTrieComputation && index == balances {
|
||||
return
|
||||
}
|
||||
totalIndicesLen := len(b.dirtyIndices[index]) + len(indices)
|
||||
if totalIndicesLen > indicesLimit {
|
||||
b.rebuildTrie[index] = true
|
||||
|
||||
@@ -2,10 +2,15 @@ package v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
types "github.com/prysmaticlabs/eth2-types"
|
||||
"github.com/prysmaticlabs/go-bitfield"
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
|
||||
stateTypes "github.com/prysmaticlabs/prysm/beacon-chain/state/types"
|
||||
"github.com/prysmaticlabs/prysm/config/params"
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
eth "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/testing/assert"
|
||||
@@ -57,3 +62,100 @@ func TestAppendBeyondIndicesLimit(t *testing.T) {
|
||||
assert.Equal(t, true, st.rebuildTrie[validators])
|
||||
assert.Equal(t, len(st.dirtyIndices[validators]), 0)
|
||||
}
|
||||
|
||||
func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
|
||||
count := uint64(100)
|
||||
vals := make([]*ethpb.Validator, 0, count)
|
||||
bals := make([]uint64, 0, count)
|
||||
for i := uint64(1); i < count; i++ {
|
||||
someRoot := [32]byte{}
|
||||
someKey := [48]byte{}
|
||||
copy(someRoot[:], strconv.Itoa(int(i)))
|
||||
copy(someKey[:], strconv.Itoa(int(i)))
|
||||
vals = append(vals, ðpb.Validator{
|
||||
PublicKey: someKey[:],
|
||||
WithdrawalCredentials: someRoot[:],
|
||||
EffectiveBalance: params.BeaconConfig().MaxEffectiveBalance,
|
||||
Slashed: false,
|
||||
ActivationEligibilityEpoch: 1,
|
||||
ActivationEpoch: 1,
|
||||
ExitEpoch: 1,
|
||||
WithdrawableEpoch: 1,
|
||||
})
|
||||
bals = append(bals, params.BeaconConfig().MaxEffectiveBalance)
|
||||
}
|
||||
zeroHash := params.BeaconConfig().ZeroHash
|
||||
mockblockRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot)
|
||||
for i := 0; i < len(mockblockRoots); i++ {
|
||||
mockblockRoots[i] = zeroHash[:]
|
||||
}
|
||||
|
||||
mockstateRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot)
|
||||
for i := 0; i < len(mockstateRoots); i++ {
|
||||
mockstateRoots[i] = zeroHash[:]
|
||||
}
|
||||
mockrandaoMixes := make([][]byte, params.BeaconConfig().EpochsPerHistoricalVector)
|
||||
for i := 0; i < len(mockrandaoMixes); i++ {
|
||||
mockrandaoMixes[i] = zeroHash[:]
|
||||
}
|
||||
var pubKeys [][]byte
|
||||
for i := uint64(0); i < params.BeaconConfig().SyncCommitteeSize; i++ {
|
||||
pubKeys = append(pubKeys, bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength))
|
||||
}
|
||||
st, err := InitializeFromProto(ðpb.BeaconStateAltair{
|
||||
Slot: 1,
|
||||
GenesisValidatorsRoot: make([]byte, 32),
|
||||
Fork: ðpb.Fork{
|
||||
PreviousVersion: make([]byte, 4),
|
||||
CurrentVersion: make([]byte, 4),
|
||||
Epoch: 0,
|
||||
},
|
||||
LatestBlockHeader: ðpb.BeaconBlockHeader{
|
||||
ParentRoot: make([]byte, 32),
|
||||
StateRoot: make([]byte, 32),
|
||||
BodyRoot: make([]byte, 32),
|
||||
},
|
||||
CurrentEpochParticipation: []byte{},
|
||||
PreviousEpochParticipation: []byte{},
|
||||
Validators: vals,
|
||||
Balances: bals,
|
||||
Eth1Data: ð.Eth1Data{
|
||||
DepositRoot: make([]byte, 32),
|
||||
BlockHash: make([]byte, 32),
|
||||
},
|
||||
BlockRoots: mockblockRoots,
|
||||
StateRoots: mockstateRoots,
|
||||
RandaoMixes: mockrandaoMixes,
|
||||
JustificationBits: bitfield.NewBitvector4(),
|
||||
PreviousJustifiedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)},
|
||||
CurrentJustifiedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)},
|
||||
FinalizedCheckpoint: ðpb.Checkpoint{Root: make([]byte, 32)},
|
||||
Slashings: make([]uint64, params.BeaconConfig().EpochsPerSlashingsVector),
|
||||
CurrentSyncCommittee: ðpb.SyncCommittee{
|
||||
Pubkeys: pubKeys,
|
||||
AggregatePubkey: make([]byte, 48),
|
||||
},
|
||||
NextSyncCommittee: ðpb.SyncCommittee{
|
||||
Pubkeys: pubKeys,
|
||||
AggregatePubkey: make([]byte, 48),
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
_, err = st.HashTreeRoot(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
if i%2 == 0 {
|
||||
assert.NoError(t, st.UpdateBalancesAtIndex(types.ValidatorIndex(i), 1000))
|
||||
}
|
||||
if i%3 == 0 {
|
||||
assert.NoError(t, st.AppendBalance(1000))
|
||||
}
|
||||
}
|
||||
_, err = st.HashTreeRoot(context.Background())
|
||||
assert.NoError(t, err)
|
||||
newRt := bytesutil.ToBytes32(st.merkleLayers[0][balances])
|
||||
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.state.Balances)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, wantedRt, newRt, "state roots are unequal")
|
||||
}
|
||||
|
||||
@@ -102,6 +102,7 @@ func (b *BeaconState) SetBalances(val []uint64) error {
|
||||
b.sharedFieldReferences[balances] = stateutil.NewRef(1)
|
||||
|
||||
b.state.Balances = val
|
||||
b.rebuildTrie[balances] = true
|
||||
b.markFieldAsDirty(balances)
|
||||
return nil
|
||||
}
|
||||
@@ -128,6 +129,7 @@ func (b *BeaconState) UpdateBalancesAtIndex(idx types.ValidatorIndex, val uint64
|
||||
bals[idx] = val
|
||||
b.state.Balances = bals
|
||||
b.markFieldAsDirty(balances)
|
||||
b.addDirtyIndices(balances, []uint64{uint64(idx)})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -219,7 +221,9 @@ func (b *BeaconState) AppendBalance(bal uint64) error {
|
||||
}
|
||||
|
||||
b.state.Balances = append(bals, bal)
|
||||
balIdx := len(b.state.Balances) - 1
|
||||
b.markFieldAsDirty(balances)
|
||||
b.addDirtyIndices(balances, []uint64{uint64(balIdx)})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/fieldtrie"
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/types"
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/config/params"
|
||||
"github.com/prysmaticlabs/prysm/container/slice"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
@@ -324,6 +325,20 @@ func (b *BeaconState) rootSelector(ctx context.Context, field types.FieldIndex)
|
||||
}
|
||||
return b.recomputeFieldTrie(validators, b.state.Validators)
|
||||
case balances:
|
||||
if features.Get().EnableBalanceTrieComputation {
|
||||
if b.rebuildTrie[field] {
|
||||
maxBalCap := params.BeaconConfig().ValidatorRegistryLimit
|
||||
elemSize := uint64(8)
|
||||
balLimit := (maxBalCap*elemSize + 31) / 32
|
||||
err := b.resetFieldTrie(field, b.state.Balances, balLimit)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
delete(b.rebuildTrie, field)
|
||||
return b.stateFieldLeaves[field].TrieRoot()
|
||||
}
|
||||
return b.recomputeFieldTrie(balances, b.state.Balances)
|
||||
}
|
||||
return stateutil.Uint64ListRootWithRegistryLimit(b.state.Balances)
|
||||
case randaoMixes:
|
||||
if b.rebuildTrie[field] {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/config/params"
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
@@ -13,6 +14,12 @@ import (
|
||||
"github.com/prysmaticlabs/prysm/testing/require"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
|
||||
defer resetCfg()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestValidatorMap_DistinctCopy(t *testing.T) {
|
||||
count := uint64(100)
|
||||
vals := make([]*ethpb.Validator, 0, count)
|
||||
|
||||
@@ -22,6 +22,9 @@ func init() {
|
||||
// Initialize the composite arrays.
|
||||
fieldMap[types.Eth1DataVotes] = types.CompositeArray
|
||||
fieldMap[types.Validators] = types.CompositeArray
|
||||
|
||||
// Initialize Compressed Arrays
|
||||
fieldMap[types.Balances] = types.CompressedArray
|
||||
}
|
||||
|
||||
// fieldMap keeps track of each field
|
||||
|
||||
@@ -52,6 +52,7 @@ type Flags struct {
|
||||
EnableHistoricalSpaceRepresentation bool // EnableHistoricalSpaceRepresentation enables the saving of registry validators in separate buckets to save space
|
||||
EnableGetBlockOptimizations bool // EnableGetBlockOptimizations optimizes some elements of the GetBlock() function.
|
||||
EnableBatchVerification bool // EnableBatchVerification enables batch signature verification on gossip messages.
|
||||
EnableBalanceTrieComputation bool // EnableBalanceTrieComputation enables our beacon state to use balance tries for hash tree root operations.
|
||||
// Logging related toggles.
|
||||
DisableGRPCConnectionLogs bool // Disables logging when a new grpc client has connected.
|
||||
|
||||
@@ -223,6 +224,10 @@ func ConfigureBeaconChain(ctx *cli.Context) {
|
||||
logEnabled(enableBatchGossipVerification)
|
||||
cfg.EnableBatchVerification = true
|
||||
}
|
||||
if ctx.Bool(enableBalanceTrieComputation.Name) {
|
||||
logEnabled(enableBalanceTrieComputation)
|
||||
cfg.EnableBalanceTrieComputation = true
|
||||
}
|
||||
Init(cfg)
|
||||
}
|
||||
|
||||
|
||||
@@ -139,6 +139,10 @@ var (
|
||||
Name: "enable-batch-gossip-verification",
|
||||
Usage: "This enables batch verification of signatures received over gossip.",
|
||||
}
|
||||
enableBalanceTrieComputation = &cli.BoolFlag{
|
||||
Name: "enable-balance-trie-computation",
|
||||
Usage: "This enables optimized hash tree root operations for our balance field.",
|
||||
}
|
||||
)
|
||||
|
||||
// devModeFlags holds list of flags that are set when development mode is on.
|
||||
@@ -147,6 +151,7 @@ var devModeFlags = []cli.Flag{
|
||||
forceOptMaxCoverAggregationStategy,
|
||||
enableGetBlockOptimizations,
|
||||
enableBatchGossipVerification,
|
||||
enableBalanceTrieComputation,
|
||||
}
|
||||
|
||||
// ValidatorFlags contains a list of all the feature flags that apply to the validator client.
|
||||
@@ -192,6 +197,7 @@ var BeaconChainFlags = append(deprecatedFlags, []cli.Flag{
|
||||
disableCorrectlyPruneCanonicalAtts,
|
||||
disableActiveBalanceCache,
|
||||
enableBatchGossipVerification,
|
||||
enableBalanceTrieComputation,
|
||||
}...)
|
||||
|
||||
// E2EBeaconChainFlags contains a list of the beacon chain feature flags to be tested in E2E.
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/minio/sha256-simd"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prysmaticlabs/go-bitfield"
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
)
|
||||
|
||||
const bytesPerChunk = 32
|
||||
@@ -113,6 +114,53 @@ func Pack(serializedItems [][]byte) ([][]byte, error) {
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// PackByChunk a given byte array's final chunk with zeroes if needed.
|
||||
func PackByChunk(serializedItems [][]byte) ([][bytesPerChunk]byte, error) {
|
||||
emptyChunk := [bytesPerChunk]byte{}
|
||||
// If there are no items, we return an empty chunk.
|
||||
if len(serializedItems) == 0 {
|
||||
return [][bytesPerChunk]byte{emptyChunk}, nil
|
||||
} else if len(serializedItems[0]) == bytesPerChunk {
|
||||
// If each item has exactly BYTES_PER_CHUNK length, we return the list of serialized items.
|
||||
chunks := make([][bytesPerChunk]byte, 0, len(serializedItems))
|
||||
for _, c := range serializedItems {
|
||||
chunks = append(chunks, bytesutil.ToBytes32(c))
|
||||
}
|
||||
return chunks, nil
|
||||
}
|
||||
// We flatten the list in order to pack its items into byte chunks correctly.
|
||||
var orderedItems []byte
|
||||
for _, item := range serializedItems {
|
||||
orderedItems = append(orderedItems, item...)
|
||||
}
|
||||
// If all our serialized item slices are length zero, we
|
||||
// exit early.
|
||||
if len(orderedItems) == 0 {
|
||||
return [][bytesPerChunk]byte{emptyChunk}, nil
|
||||
}
|
||||
numItems := len(orderedItems)
|
||||
var chunks [][bytesPerChunk]byte
|
||||
for i := 0; i < numItems; i += bytesPerChunk {
|
||||
j := i + bytesPerChunk
|
||||
// We create our upper bound index of the chunk, if it is greater than numItems,
|
||||
// we set it as numItems itself.
|
||||
if j > numItems {
|
||||
j = numItems
|
||||
}
|
||||
// We create chunks from the list of items based on the
|
||||
// indices determined above.
|
||||
// Right-pad the last chunk with zero bytes if it does not
|
||||
// have length bytesPerChunk from the helper.
|
||||
// The ToBytes32 helper allocates a 32-byte array, before
|
||||
// copying the ordered items in. This ensures that even if
|
||||
// the last chunk is != 32 in length, we will right-pad it with
|
||||
// zero bytes.
|
||||
chunks = append(chunks, bytesutil.ToBytes32(orderedItems[i:j]))
|
||||
}
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// MixInLength appends hash length to root
|
||||
func MixInLength(root [32]byte, length []byte) [32]byte {
|
||||
var hash [32]byte
|
||||
|
||||
@@ -94,6 +94,22 @@ func TestPack(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackByChunk(t *testing.T) {
|
||||
byteSlice2D := [][]byte{
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 2, 5, 2, 6, 2, 7},
|
||||
{1, 1, 2, 3, 5, 8, 13, 21, 34},
|
||||
}
|
||||
expected := [][32]byte{{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 2, 5, 2, 6, 2, 7, 1, 1},
|
||||
{2, 3, 5, 8, 13, 21, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}
|
||||
|
||||
result, err := ssz.PackByChunk(byteSlice2D)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(expected), len(result))
|
||||
for i, v := range expected {
|
||||
assert.DeepEqual(t, v, result[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMixInLength(t *testing.T) {
|
||||
byteSlice := [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
|
||||
length := []byte{1, 2, 3}
|
||||
|
||||
@@ -5,6 +5,7 @@ go_test(
|
||||
size = "small",
|
||||
srcs = [
|
||||
"effective_balance_updates_test.go",
|
||||
"epoch_processing_test.go",
|
||||
"eth1_data_reset_test.go",
|
||||
"historical_roots_update_test.go",
|
||||
"inactivity_updates_test.go",
|
||||
@@ -21,5 +22,8 @@ go_test(
|
||||
],
|
||||
shard_count = 4,
|
||||
tags = ["spectest"],
|
||||
deps = ["//testing/spectest/shared/altair/epoch_processing:go_default_library"],
|
||||
deps = [
|
||||
"//config/features:go_default_library",
|
||||
"//testing/spectest/shared/altair/epoch_processing:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package epoch_processing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
|
||||
defer resetCfg()
|
||||
m.Run()
|
||||
}
|
||||
@@ -8,5 +8,8 @@ go_test(
|
||||
"@consensus_spec_tests_mainnet//:test_data",
|
||||
],
|
||||
tags = ["spectest"],
|
||||
deps = ["//testing/spectest/shared/altair/sanity:go_default_library"],
|
||||
deps = [
|
||||
"//config/features:go_default_library",
|
||||
"//testing/spectest/shared/altair/sanity:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -3,9 +3,16 @@ package random
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/testing/spectest/shared/altair/sanity"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
|
||||
defer resetCfg()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestMainnet_Altair_Random(t *testing.T) {
|
||||
sanity.RunBlockProcessingTest(t, "mainnet", "random/random/pyspec_tests")
|
||||
}
|
||||
|
||||
@@ -8,5 +8,8 @@ go_test(
|
||||
"@consensus_spec_tests_mainnet//:test_data",
|
||||
],
|
||||
tags = ["spectest"],
|
||||
deps = ["//testing/spectest/shared/altair/rewards:go_default_library"],
|
||||
deps = [
|
||||
"//config/features:go_default_library",
|
||||
"//testing/spectest/shared/altair/rewards:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -3,9 +3,16 @@ package rewards
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/testing/spectest/shared/altair/rewards"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
|
||||
defer resetCfg()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestMainnet_Altair_Rewards(t *testing.T) {
|
||||
rewards.RunPrecomputeRewardsAndPenaltiesTests(t, "mainnet")
|
||||
}
|
||||
|
||||
@@ -5,11 +5,15 @@ go_test(
|
||||
size = "medium",
|
||||
srcs = [
|
||||
"blocks_test.go",
|
||||
"sanity_test.go",
|
||||
"slots_test.go",
|
||||
],
|
||||
data = glob(["*.yaml"]) + [
|
||||
"@consensus_spec_tests_mainnet//:test_data",
|
||||
],
|
||||
tags = ["spectest"],
|
||||
deps = ["//testing/spectest/shared/altair/sanity:go_default_library"],
|
||||
deps = [
|
||||
"//config/features:go_default_library",
|
||||
"//testing/spectest/shared/altair/sanity:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
13
testing/spectest/mainnet/altair/sanity/sanity_test.go
Normal file
13
testing/spectest/mainnet/altair/sanity/sanity_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package sanity
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
|
||||
defer resetCfg()
|
||||
m.Run()
|
||||
}
|
||||
@@ -22,6 +22,7 @@ go_test(
|
||||
shard_count = 4,
|
||||
tags = ["spectest"],
|
||||
deps = [
|
||||
"//config/features:go_default_library",
|
||||
"//config/params:go_default_library",
|
||||
"//testing/spectest/shared/phase0/epoch_processing:go_default_library",
|
||||
],
|
||||
|
||||
@@ -3,6 +3,7 @@ package epoch_processing
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/config/params"
|
||||
)
|
||||
|
||||
@@ -12,6 +13,8 @@ func TestMain(m *testing.M) {
|
||||
c := params.BeaconConfig()
|
||||
c.MinGenesisActiveValidatorCount = 16384
|
||||
params.OverrideBeaconConfig(c)
|
||||
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
|
||||
defer resetCfg()
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
@@ -8,5 +8,8 @@ go_test(
|
||||
"@consensus_spec_tests_mainnet//:test_data",
|
||||
],
|
||||
tags = ["spectest"],
|
||||
deps = ["//testing/spectest/shared/phase0/sanity:go_default_library"],
|
||||
deps = [
|
||||
"//config/features:go_default_library",
|
||||
"//testing/spectest/shared/phase0/sanity:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -3,9 +3,16 @@ package random
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/testing/spectest/shared/phase0/sanity"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
|
||||
defer resetCfg()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestMainnet_Phase0_Random(t *testing.T) {
|
||||
sanity.RunBlockProcessingTest(t, "mainnet", "random/random/pyspec_tests")
|
||||
}
|
||||
|
||||
@@ -8,5 +8,8 @@ go_test(
|
||||
"@consensus_spec_tests_mainnet//:test_data",
|
||||
],
|
||||
tags = ["spectest"],
|
||||
deps = ["//testing/spectest/shared/phase0/rewards:go_default_library"],
|
||||
deps = [
|
||||
"//config/features:go_default_library",
|
||||
"//testing/spectest/shared/phase0/rewards:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -3,9 +3,16 @@ package rewards
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/testing/spectest/shared/phase0/rewards"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
|
||||
defer resetCfg()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestMainnet_Phase0_Rewards(t *testing.T) {
|
||||
rewards.RunPrecomputeRewardsAndPenaltiesTests(t, "mainnet")
|
||||
}
|
||||
|
||||
@@ -5,11 +5,15 @@ go_test(
|
||||
size = "medium",
|
||||
srcs = [
|
||||
"blocks_test.go",
|
||||
"sanity_test.go",
|
||||
"slots_test.go",
|
||||
],
|
||||
data = glob(["*.yaml"]) + [
|
||||
"@consensus_spec_tests_mainnet//:test_data",
|
||||
],
|
||||
tags = ["spectest"],
|
||||
deps = ["//testing/spectest/shared/phase0/sanity:go_default_library"],
|
||||
deps = [
|
||||
"//config/features:go_default_library",
|
||||
"//testing/spectest/shared/phase0/sanity:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
13
testing/spectest/mainnet/phase0/sanity/sanity_test.go
Normal file
13
testing/spectest/mainnet/phase0/sanity/sanity_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package sanity
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
|
||||
defer resetCfg()
|
||||
m.Run()
|
||||
}
|
||||
Reference in New Issue
Block a user