Add Balance Field Trie (#9793)

* save stuff

* fix in v1

* clean up more

* fix bugs

* add comments and clean up

* add flag + test

* add tests

* fmt

* radek's review

* gaz

* kasey's review

* gaz and new conditional

* improve naming
This commit is contained in:
Nishant Das
2021-11-19 20:01:15 +08:00
committed by GitHub
parent ee52f8dff3
commit 0ade1f121d
44 changed files with 717 additions and 122 deletions

View File

@@ -12,6 +12,8 @@ go_library(
"//beacon-chain/state/stateutil:go_default_library",
"//beacon-chain/state/types:go_default_library",
"//crypto/hash:go_default_library",
"//encoding/bytesutil:go_default_library",
"//encoding/ssz:go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
"//runtime/version:go_default_library",
"@com_github_pkg_errors//:go_default_library",
@@ -26,6 +28,7 @@ go_test(
],
embed = [":go_default_library"],
deps = [
"//beacon-chain/state/stateutil:go_default_library",
"//beacon-chain/state/types:go_default_library",
"//beacon-chain/state/v1:go_default_library",
"//config/params:go_default_library",

View File

@@ -18,6 +18,7 @@ type FieldTrie struct {
field types.FieldIndex
dataType types.DataType
length uint64
numOfElems int
}
// NewFieldTrie is the constructor for the field trie data structure. It creates the corresponding
@@ -26,18 +27,19 @@ type FieldTrie struct {
func NewFieldTrie(field types.FieldIndex, dataType types.DataType, elements interface{}, length uint64) (*FieldTrie, error) {
if elements == nil {
return &FieldTrie{
field: field,
dataType: dataType,
reference: stateutil.NewRef(1),
RWMutex: new(sync.RWMutex),
length: length,
field: field,
dataType: dataType,
reference: stateutil.NewRef(1),
RWMutex: new(sync.RWMutex),
length: length,
numOfElems: 0,
}, nil
}
fieldRoots, err := fieldConverters(field, []uint64{}, elements, true)
if err != nil {
return nil, err
}
if err := validateElements(field, elements, length); err != nil {
if err := validateElements(field, dataType, elements, length); err != nil {
return nil, err
}
switch dataType {
@@ -53,8 +55,9 @@ func NewFieldTrie(field types.FieldIndex, dataType types.DataType, elements inte
reference: stateutil.NewRef(1),
RWMutex: new(sync.RWMutex),
length: length,
numOfElems: reflect.ValueOf(elements).Len(),
}, nil
case types.CompositeArray:
case types.CompositeArray, types.CompressedArray:
return &FieldTrie{
fieldLayers: stateutil.ReturnTrieLayerVariable(fieldRoots, length),
field: field,
@@ -62,6 +65,7 @@ func NewFieldTrie(field types.FieldIndex, dataType types.DataType, elements inte
reference: stateutil.NewRef(1),
RWMutex: new(sync.RWMutex),
length: length,
numOfElems: reflect.ValueOf(elements).Len(),
}, nil
default:
return nil, errors.Errorf("unrecognized data type in field map: %v", reflect.TypeOf(dataType).Name())
@@ -92,13 +96,40 @@ func (f *FieldTrie) RecomputeTrie(indices []uint64, elements interface{}) ([32]b
if err != nil {
return [32]byte{}, err
}
f.numOfElems = reflect.ValueOf(elements).Len()
return fieldRoot, nil
case types.CompositeArray:
fieldRoot, f.fieldLayers, err = stateutil.RecomputeFromLayerVariable(fieldRoots, indices, f.fieldLayers)
if err != nil {
return [32]byte{}, err
}
f.numOfElems = reflect.ValueOf(elements).Len()
return stateutil.AddInMixin(fieldRoot, uint64(len(f.fieldLayers[0])))
case types.CompressedArray:
numOfElems, err := f.field.ElemsInChunk()
if err != nil {
return [32]byte{}, err
}
// We remove the duplicates here in order to prevent
// duplicated insertions into the trie.
newIndices := []uint64{}
indexExists := make(map[uint64]bool)
newRoots := make([][32]byte, 0, len(fieldRoots)/int(numOfElems))
for i, idx := range indices {
startIdx := idx / numOfElems
if indexExists[startIdx] {
continue
}
newIndices = append(newIndices, startIdx)
indexExists[startIdx] = true
newRoots = append(newRoots, fieldRoots[i])
}
fieldRoot, f.fieldLayers, err = stateutil.RecomputeFromLayerVariable(newRoots, newIndices, f.fieldLayers)
if err != nil {
return [32]byte{}, err
}
f.numOfElems = reflect.ValueOf(elements).Len()
return stateutil.AddInMixin(fieldRoot, uint64(f.numOfElems))
default:
return [32]byte{}, errors.Errorf("unrecognized data type in field map: %v", reflect.TypeOf(f.dataType).Name())
}
@@ -109,11 +140,12 @@ func (f *FieldTrie) RecomputeTrie(indices []uint64, elements interface{}) ([32]b
func (f *FieldTrie) CopyTrie() *FieldTrie {
if f.fieldLayers == nil {
return &FieldTrie{
field: f.field,
dataType: f.dataType,
reference: stateutil.NewRef(1),
RWMutex: new(sync.RWMutex),
length: f.length,
field: f.field,
dataType: f.dataType,
reference: stateutil.NewRef(1),
RWMutex: new(sync.RWMutex),
length: f.length,
numOfElems: f.numOfElems,
}
}
dstFieldTrie := make([][]*[32]byte, len(f.fieldLayers))
@@ -128,6 +160,7 @@ func (f *FieldTrie) CopyTrie() *FieldTrie {
reference: stateutil.NewRef(1),
RWMutex: new(sync.RWMutex),
length: f.length,
numOfElems: f.numOfElems,
}
}
@@ -139,6 +172,9 @@ func (f *FieldTrie) TrieRoot() ([32]byte, error) {
case types.CompositeArray:
trieRoot := *f.fieldLayers[len(f.fieldLayers)-1][0]
return stateutil.AddInMixin(trieRoot, uint64(len(f.fieldLayers[0])))
case types.CompressedArray:
trieRoot := *f.fieldLayers[len(f.fieldLayers)-1][0]
return stateutil.AddInMixin(trieRoot, uint64(f.numOfElems))
default:
return [32]byte{}, errors.Errorf("unrecognized data type in field map: %v", reflect.TypeOf(f.dataType).Name())
}

View File

@@ -1,6 +1,7 @@
package fieldtrie
import (
"encoding/binary"
"fmt"
"reflect"
@@ -8,20 +9,37 @@ import (
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
"github.com/prysmaticlabs/prysm/beacon-chain/state/types"
"github.com/prysmaticlabs/prysm/crypto/hash"
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
"github.com/prysmaticlabs/prysm/encoding/ssz"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/runtime/version"
)
func (f *FieldTrie) validateIndices(idxs []uint64) error {
length := f.length
if f.dataType == types.CompressedArray {
comLength, err := f.field.ElemsInChunk()
if err != nil {
return err
}
length *= comLength
}
for _, idx := range idxs {
if idx >= f.length {
return errors.Errorf("invalid index for field %s: %d >= length %d", f.field.String(version.Phase0), idx, f.length)
if idx >= length {
return errors.Errorf("invalid index for field %s: %d >= length %d", f.field.String(version.Phase0), idx, length)
}
}
return nil
}
func validateElements(field types.FieldIndex, elements interface{}, length uint64) error {
func validateElements(field types.FieldIndex, dataType types.DataType, elements interface{}, length uint64) error {
if dataType == types.CompressedArray {
comLength, err := field.ElemsInChunk()
if err != nil {
return err
}
length *= comLength
}
val := reflect.ValueOf(elements)
if val.Len() > int(length) {
return errors.Errorf("elements length is larger than expected for field %s: %d > %d", field.String(version.Phase0), val.Len(), length)
@@ -38,21 +56,21 @@ func fieldConverters(field types.FieldIndex, indices []uint64, elements interfac
return nil, errors.Errorf("Wanted type of %v but got %v",
reflect.TypeOf([][]byte{}).Name(), reflect.TypeOf(elements).Name())
}
return stateutil.HandleByteArrays(val, indices, convertAll)
return handleByteArrays(val, indices, convertAll)
case types.Eth1DataVotes:
val, ok := elements.([]*ethpb.Eth1Data)
if !ok {
return nil, errors.Errorf("Wanted type of %v but got %v",
reflect.TypeOf([]*ethpb.Eth1Data{}).Name(), reflect.TypeOf(elements).Name())
}
return HandleEth1DataSlice(val, indices, convertAll)
return handleEth1DataSlice(val, indices, convertAll)
case types.Validators:
val, ok := elements.([]*ethpb.Validator)
if !ok {
return nil, errors.Errorf("Wanted type of %v but got %v",
reflect.TypeOf([]*ethpb.Validator{}).Name(), reflect.TypeOf(elements).Name())
}
return stateutil.HandleValidatorSlice(val, indices, convertAll)
return handleValidatorSlice(val, indices, convertAll)
case types.PreviousEpochAttestations, types.CurrentEpochAttestations:
val, ok := elements.([]*ethpb.PendingAttestation)
if !ok {
@@ -60,13 +78,87 @@ func fieldConverters(field types.FieldIndex, indices []uint64, elements interfac
reflect.TypeOf([]*ethpb.PendingAttestation{}).Name(), reflect.TypeOf(elements).Name())
}
return handlePendingAttestation(val, indices, convertAll)
case types.Balances:
val, ok := elements.([]uint64)
if !ok {
return nil, errors.Errorf("Wanted type of %v but got %v",
reflect.TypeOf([]uint64{}).Name(), reflect.TypeOf(elements).Name())
}
return handleBalanceSlice(val, indices, convertAll)
default:
return [][32]byte{}, errors.Errorf("got unsupported type of %v", reflect.TypeOf(elements).Name())
}
}
// HandleEth1DataSlice processes a list of eth1data and indices into the appropriate roots.
func HandleEth1DataSlice(val []*ethpb.Eth1Data, indices []uint64, convertAll bool) ([][32]byte, error) {
// handleByteArrays computes and returns byte arrays in a slice of root format.
func handleByteArrays(val [][]byte, indices []uint64, convertAll bool) ([][32]byte, error) {
length := len(indices)
if convertAll {
length = len(val)
}
roots := make([][32]byte, 0, length)
rootCreator := func(input []byte) {
newRoot := bytesutil.ToBytes32(input)
roots = append(roots, newRoot)
}
if convertAll {
for i := range val {
rootCreator(val[i])
}
return roots, nil
}
if len(val) > 0 {
for _, idx := range indices {
if idx > uint64(len(val))-1 {
return nil, fmt.Errorf("index %d greater than number of byte arrays %d", idx, len(val))
}
rootCreator(val[idx])
}
}
return roots, nil
}
// handleValidatorSlice returns the validator indices in a slice of root format.
func handleValidatorSlice(val []*ethpb.Validator, indices []uint64, convertAll bool) ([][32]byte, error) {
length := len(indices)
if convertAll {
length = len(val)
}
roots := make([][32]byte, 0, length)
hasher := hash.CustomSHA256Hasher()
rootCreator := func(input *ethpb.Validator) error {
newRoot, err := stateutil.ValidatorRootWithHasher(hasher, input)
if err != nil {
return err
}
roots = append(roots, newRoot)
return nil
}
if convertAll {
for i := range val {
err := rootCreator(val[i])
if err != nil {
return nil, err
}
}
return roots, nil
}
if len(val) > 0 {
for _, idx := range indices {
if idx > uint64(len(val))-1 {
return nil, fmt.Errorf("index %d greater than number of validators %d", idx, len(val))
}
err := rootCreator(val[idx])
if err != nil {
return nil, err
}
}
}
return roots, nil
}
// handleEth1DataSlice processes a list of eth1data and indices into the appropriate roots.
func handleEth1DataSlice(val []*ethpb.Eth1Data, indices []uint64, convertAll bool) ([][32]byte, error) {
length := len(indices)
if convertAll {
length = len(val)
@@ -141,3 +233,48 @@ func handlePendingAttestation(val []*ethpb.PendingAttestation, indices []uint64,
}
return roots, nil
}
func handleBalanceSlice(val []uint64, indices []uint64, convertAll bool) ([][32]byte, error) {
if convertAll {
balancesMarshaling := make([][]byte, 0)
for _, b := range val {
balanceBuf := make([]byte, 8)
binary.LittleEndian.PutUint64(balanceBuf, b)
balancesMarshaling = append(balancesMarshaling, balanceBuf)
}
balancesChunks, err := ssz.PackByChunk(balancesMarshaling)
if err != nil {
return [][32]byte{}, errors.Wrap(err, "could not pack balances into chunks")
}
return balancesChunks, nil
}
if len(val) > 0 {
numOfElems, err := types.Balances.ElemsInChunk()
if err != nil {
return nil, err
}
roots := [][32]byte{}
for _, idx := range indices {
// We split the indexes into their relevant groups. Balances
// are compressed according to 4 values -> 1 chunk.
startIdx := idx / numOfElems
startGroup := startIdx * numOfElems
chunk := [32]byte{}
sizeOfElem := len(chunk) / int(numOfElems)
for i, j := 0, startGroup; j < startGroup+numOfElems; i, j = i+sizeOfElem, j+1 {
wantedVal := uint64(0)
// We are adding chunks in sets of 4, if the set is at the edge of the array
// then you will need to zero out the rest of the chunk. Ex : 41 indexes,
// so 41 % 4 = 1 . There are 3 indexes, which do not exist yet but we
// have to add in as a root. These 3 indexes are then given a 'zero' value.
if int(j) < len(val) {
wantedVal = val[j]
}
binary.LittleEndian.PutUint64(chunk[i:i+sizeOfElem], wantedVal)
}
roots = append(roots, chunk)
}
return roots, nil
}
return [][32]byte{}, nil
}

View File

@@ -1,8 +1,13 @@
package fieldtrie
import (
"encoding/binary"
"sync"
"testing"
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
"github.com/prysmaticlabs/prysm/beacon-chain/state/types"
"github.com/prysmaticlabs/prysm/config/params"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/testing/assert"
)
@@ -17,7 +22,64 @@ func Test_handlePendingAttestation_OutOfRange(t *testing.T) {
func Test_handleEth1DataSlice_OutOfRange(t *testing.T) {
items := make([]*ethpb.Eth1Data, 1)
indices := []uint64{3}
_, err := HandleEth1DataSlice(items, indices, false)
_, err := handleEth1DataSlice(items, indices, false)
assert.ErrorContains(t, "index 3 greater than number of items in eth1 data slice 1", err)
}
func Test_handleValidatorSlice_OutOfRange(t *testing.T) {
vals := make([]*ethpb.Validator, 1)
indices := []uint64{3}
_, err := handleValidatorSlice(vals, indices, false)
assert.ErrorContains(t, "index 3 greater than number of validators 1", err)
}
func TestBalancesSlice_CorrectRoots_All(t *testing.T) {
balances := []uint64{5, 2929, 34, 1291, 354305}
roots, err := handleBalanceSlice(balances, []uint64{}, true)
assert.NoError(t, err)
root1 := [32]byte{}
binary.LittleEndian.PutUint64(root1[:8], balances[0])
binary.LittleEndian.PutUint64(root1[8:16], balances[1])
binary.LittleEndian.PutUint64(root1[16:24], balances[2])
binary.LittleEndian.PutUint64(root1[24:32], balances[3])
root2 := [32]byte{}
binary.LittleEndian.PutUint64(root2[:8], balances[4])
assert.DeepEqual(t, roots, [][32]byte{root1, root2})
}
func TestBalancesSlice_CorrectRoots_Some(t *testing.T) {
balances := []uint64{5, 2929, 34, 1291, 354305}
roots, err := handleBalanceSlice(balances, []uint64{2, 3}, false)
assert.NoError(t, err)
root1 := [32]byte{}
binary.LittleEndian.PutUint64(root1[:8], balances[0])
binary.LittleEndian.PutUint64(root1[8:16], balances[1])
binary.LittleEndian.PutUint64(root1[16:24], balances[2])
binary.LittleEndian.PutUint64(root1[24:32], balances[3])
// Returns root for each indice(even if duplicated)
assert.DeepEqual(t, roots, [][32]byte{root1, root1})
}
func TestValidateIndices_CompressedField(t *testing.T) {
fakeTrie := &FieldTrie{
RWMutex: new(sync.RWMutex),
reference: stateutil.NewRef(0),
fieldLayers: nil,
field: types.Balances,
dataType: types.CompressedArray,
length: params.BeaconConfig().ValidatorRegistryLimit / 4,
numOfElems: 0,
}
goodIdx := params.BeaconConfig().ValidatorRegistryLimit - 1
assert.NoError(t, fakeTrie.validateIndices([]uint64{goodIdx}))
badIdx := goodIdx + 1
assert.ErrorContains(t, "invalid index for field balances", fakeTrie.validateIndices([]uint64{badIdx}))
}

View File

@@ -3,7 +3,6 @@ load("@prysm//tools/go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
srcs = [
"array_root.go",
"block_header_root.go",
"eth1_root.go",
"participation_bit_root.go",
@@ -49,7 +48,6 @@ go_test(
"state_root_test.go",
"stateutil_test.go",
"trie_helpers_test.go",
"validator_root_test.go",
],
embed = [":go_default_library"],
deps = [

View File

@@ -1,35 +0,0 @@
package stateutil
import (
"fmt"
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
)
// HandleByteArrays computes and returns byte arrays in a slice of root format.
func HandleByteArrays(val [][]byte, indices []uint64, convertAll bool) ([][32]byte, error) {
length := len(indices)
if convertAll {
length = len(val)
}
roots := make([][32]byte, 0, length)
rootCreator := func(input []byte) {
newRoot := bytesutil.ToBytes32(input)
roots = append(roots, newRoot)
}
if convertAll {
for i := range val {
rootCreator(val[i])
}
return roots, nil
}
if len(val) > 0 {
for _, idx := range indices {
if idx > uint64(len(val))-1 {
return nil, fmt.Errorf("index %d greater than number of byte arrays %d", idx, len(val))
}
rootCreator(val[idx])
}
}
return roots, nil
}

View File

@@ -2,7 +2,6 @@ package stateutil
import (
"encoding/binary"
"fmt"
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/config/params"
@@ -127,42 +126,3 @@ func ValidatorEncKey(validator *ethpb.Validator) []byte {
return enc
}
// HandleValidatorSlice returns the validator indices in a slice of root format.
func HandleValidatorSlice(val []*ethpb.Validator, indices []uint64, convertAll bool) ([][32]byte, error) {
length := len(indices)
if convertAll {
length = len(val)
}
roots := make([][32]byte, 0, length)
hasher := hash.CustomSHA256Hasher()
rootCreator := func(input *ethpb.Validator) error {
newRoot, err := ValidatorRootWithHasher(hasher, input)
if err != nil {
return err
}
roots = append(roots, newRoot)
return nil
}
if convertAll {
for i := range val {
err := rootCreator(val[i])
if err != nil {
return nil, err
}
}
return roots, nil
}
if len(val) > 0 {
for _, idx := range indices {
if idx > uint64(len(val))-1 {
return nil, fmt.Errorf("index %d greater than number of validators %d", idx, len(val))
}
err := rootCreator(val[idx])
if err != nil {
return nil, err
}
}
}
return roots, nil
}

View File

@@ -1,15 +0,0 @@
package stateutil
import (
"testing"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/testing/assert"
)
func Test_handleValidatorSlice_OutOfRange(t *testing.T) {
vals := make([]*ethpb.Validator, 1)
indices := []uint64{3}
_, err := HandleValidatorSlice(vals, indices, false)
assert.ErrorContains(t, "index 3 greater than number of validators 1", err)
}

View File

@@ -5,7 +5,10 @@ go_library(
srcs = ["types.go"],
importpath = "github.com/prysmaticlabs/prysm/beacon-chain/state/types",
visibility = ["//beacon-chain:__subpackages__"],
deps = ["//runtime/version:go_default_library"],
deps = [
"//runtime/version:go_default_library",
"@com_github_pkg_errors//:go_default_library",
],
)
go_test(

View File

@@ -1,6 +1,7 @@
package types
import (
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/runtime/version"
)
@@ -18,6 +19,10 @@ const (
// CompositeArray represents a variable length array with
// a non primitive type.
CompositeArray
// CompressedArray represents a variable length array which
// can pack multiple elements into a leaf of the underlying
// trie.
CompressedArray
)
// String returns the name of the field index.
@@ -84,6 +89,17 @@ func (f FieldIndex) String(stateVersion int) string {
}
}
// ElemsInChunk returns the number of elements in the chunk (number of
// elements that are able to be packed).
func (f FieldIndex) ElemsInChunk() (uint64, error) {
switch f {
case Balances:
return 4, nil
default:
return 0, errors.Errorf("field %d doesn't support element compression", f)
}
}
// Below we define a set of useful enum values for the field
// indices of the beacon state. For example, genesisTime is the
// 0th field of the beacon state. This is helpful when we are

View File

@@ -87,6 +87,7 @@ go_test(
"//beacon-chain/state:go_default_library",
"//beacon-chain/state/stateutil:go_default_library",
"//beacon-chain/state/types:go_default_library",
"//config/features:go_default_library",
"//config/params:go_default_library",
"//encoding/bytesutil:go_default_library",
"//proto/prysm/v1alpha1:go_default_library",

View File

@@ -5,6 +5,7 @@ import (
types "github.com/prysmaticlabs/eth2-types"
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
stateTypes "github.com/prysmaticlabs/prysm/beacon-chain/state/types"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/crypto/hash"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"google.golang.org/protobuf/proto"
@@ -172,6 +173,10 @@ func (b *BeaconState) addDirtyIndices(index stateTypes.FieldIndex, indices []uin
if b.rebuildTrie[index] {
return
}
// Exit early if balance trie computation isn't enabled.
if !features.Get().EnableBalanceTrieComputation && index == balances {
return
}
totalIndicesLen := len(b.dirtyIndices[index]) + len(indices)
if totalIndicesLen > indicesLimit {
b.rebuildTrie[index] = true

View File

@@ -103,6 +103,7 @@ func (b *BeaconState) SetBalances(val []uint64) error {
b.state.Balances = val
b.markFieldAsDirty(balances)
b.rebuildTrie[balances] = true
return nil
}
@@ -128,6 +129,7 @@ func (b *BeaconState) UpdateBalancesAtIndex(idx types.ValidatorIndex, val uint64
bals[idx] = val
b.state.Balances = bals
b.markFieldAsDirty(balances)
b.addDirtyIndices(balances, []uint64{uint64(idx)})
return nil
}
@@ -219,6 +221,8 @@ func (b *BeaconState) AppendBalance(bal uint64) error {
}
b.state.Balances = append(bals, bal)
balIdx := len(b.state.Balances) - 1
b.markFieldAsDirty(balances)
b.addDirtyIndices(balances, []uint64{uint64(balIdx)})
return nil
}

View File

@@ -1,10 +1,13 @@
package v1
import (
"context"
"strconv"
"sync"
"testing"
types "github.com/prysmaticlabs/eth2-types"
"github.com/prysmaticlabs/go-bitfield"
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
"github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
@@ -104,3 +107,90 @@ func TestStateTrie_IsNil(t *testing.T) {
nonNilState := &BeaconState{state: &ethpb.BeaconState{}}
assert.Equal(t, false, nonNilState.IsNil())
}
func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
count := uint64(100)
vals := make([]*ethpb.Validator, 0, count)
bals := make([]uint64, 0, count)
for i := uint64(1); i < count; i++ {
someRoot := [32]byte{}
someKey := [48]byte{}
copy(someRoot[:], strconv.Itoa(int(i)))
copy(someKey[:], strconv.Itoa(int(i)))
vals = append(vals, &ethpb.Validator{
PublicKey: someKey[:],
WithdrawalCredentials: someRoot[:],
EffectiveBalance: params.BeaconConfig().MaxEffectiveBalance,
Slashed: false,
ActivationEligibilityEpoch: 1,
ActivationEpoch: 1,
ExitEpoch: 1,
WithdrawableEpoch: 1,
})
bals = append(bals, params.BeaconConfig().MaxEffectiveBalance)
}
zeroHash := params.BeaconConfig().ZeroHash
mockblockRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot)
for i := 0; i < len(mockblockRoots); i++ {
mockblockRoots[i] = zeroHash[:]
}
mockstateRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot)
for i := 0; i < len(mockstateRoots); i++ {
mockstateRoots[i] = zeroHash[:]
}
mockrandaoMixes := make([][]byte, params.BeaconConfig().EpochsPerHistoricalVector)
for i := 0; i < len(mockrandaoMixes); i++ {
mockrandaoMixes[i] = zeroHash[:]
}
var pubKeys [][]byte
for i := uint64(0); i < params.BeaconConfig().SyncCommitteeSize; i++ {
pubKeys = append(pubKeys, bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength))
}
st, err := InitializeFromProto(&ethpb.BeaconState{
Slot: 1,
GenesisValidatorsRoot: make([]byte, 32),
Fork: &ethpb.Fork{
PreviousVersion: make([]byte, 4),
CurrentVersion: make([]byte, 4),
Epoch: 0,
},
LatestBlockHeader: &ethpb.BeaconBlockHeader{
ParentRoot: make([]byte, 32),
StateRoot: make([]byte, 32),
BodyRoot: make([]byte, 32),
},
Validators: vals,
Balances: bals,
Eth1Data: &ethpb.Eth1Data{
DepositRoot: make([]byte, 32),
BlockHash: make([]byte, 32),
},
BlockRoots: mockblockRoots,
StateRoots: mockstateRoots,
RandaoMixes: mockrandaoMixes,
JustificationBits: bitfield.NewBitvector4(),
PreviousJustifiedCheckpoint: &ethpb.Checkpoint{Root: make([]byte, 32)},
CurrentJustifiedCheckpoint: &ethpb.Checkpoint{Root: make([]byte, 32)},
FinalizedCheckpoint: &ethpb.Checkpoint{Root: make([]byte, 32)},
Slashings: make([]uint64, params.BeaconConfig().EpochsPerSlashingsVector),
})
assert.NoError(t, err)
_, err = st.HashTreeRoot(context.Background())
assert.NoError(t, err)
for i := 0; i < 100; i++ {
if i%2 == 0 {
assert.NoError(t, st.UpdateBalancesAtIndex(types.ValidatorIndex(i), 1000))
}
if i%3 == 0 {
assert.NoError(t, st.AppendBalance(1000))
}
}
_, err = st.HashTreeRoot(context.Background())
assert.NoError(t, err)
newRt := bytesutil.ToBytes32(st.merkleLayers[0][balances])
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.state.Balances)
assert.NoError(t, err)
assert.Equal(t, wantedRt, newRt, "state roots are unequal")
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/prysmaticlabs/prysm/beacon-chain/state/fieldtrie"
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
"github.com/prysmaticlabs/prysm/beacon-chain/state/types"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/container/slice"
"github.com/prysmaticlabs/prysm/crypto/hash"
@@ -319,6 +320,20 @@ func (b *BeaconState) rootSelector(ctx context.Context, field types.FieldIndex)
}
return b.recomputeFieldTrie(validators, b.state.Validators)
case balances:
if features.Get().EnableBalanceTrieComputation {
if b.rebuildTrie[field] {
maxBalCap := params.BeaconConfig().ValidatorRegistryLimit
elemSize := uint64(8)
balLimit := (maxBalCap*elemSize + 31) / 32
err := b.resetFieldTrie(field, b.state.Balances, balLimit)
if err != nil {
return [32]byte{}, err
}
delete(b.rebuildTrie, field)
return b.stateFieldLeaves[field].TrieRoot()
}
return b.recomputeFieldTrie(balances, b.state.Balances)
}
return stateutil.Uint64ListRootWithRegistryLimit(b.state.Balances)
case randaoMixes:
if b.rebuildTrie[field] {

View File

@@ -7,6 +7,7 @@ import (
"github.com/prysmaticlabs/prysm/beacon-chain/state"
v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
eth "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
@@ -16,6 +17,12 @@ import (
"github.com/prysmaticlabs/prysm/testing/util"
)
func TestMain(m *testing.M) {
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
defer resetCfg()
m.Run()
}
func TestInitializeFromProto(t *testing.T) {
testState, _ := util.DeterministicGenesisState(t, 64)
pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe())

View File

@@ -17,7 +17,6 @@ var _ state.BeaconState = (*BeaconState)(nil)
func init() {
fieldMap = make(map[types.FieldIndex]types.DataType, params.BeaconConfig().BeaconStateFieldCount)
// Initialize the fixed sized arrays.
fieldMap[types.BlockRoots] = types.BasicArray
fieldMap[types.StateRoots] = types.BasicArray
@@ -28,6 +27,7 @@ func init() {
fieldMap[types.Validators] = types.CompositeArray
fieldMap[types.PreviousEpochAttestations] = types.CompositeArray
fieldMap[types.CurrentEpochAttestations] = types.CompositeArray
fieldMap[types.Balances] = types.CompressedArray
}
// fieldMap keeps track of each field

View File

@@ -79,11 +79,13 @@ go_test(
"//beacon-chain/state/stateutil:go_default_library",
"//beacon-chain/state/types:go_default_library",
"//beacon-chain/state/v1:go_default_library",
"//config/features:go_default_library",
"//config/params:go_default_library",
"//encoding/bytesutil:go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
"//testing/assert:go_default_library",
"//testing/require:go_default_library",
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
],
)

View File

@@ -5,6 +5,7 @@ import (
types "github.com/prysmaticlabs/eth2-types"
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
stateTypes "github.com/prysmaticlabs/prysm/beacon-chain/state/types"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/crypto/hash"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"google.golang.org/protobuf/proto"
@@ -171,6 +172,10 @@ func (b *BeaconState) addDirtyIndices(index stateTypes.FieldIndex, indices []uin
if b.rebuildTrie[index] {
return
}
// Exit early if balance trie computation isn't enabled.
if !features.Get().EnableBalanceTrieComputation && index == balances {
return
}
totalIndicesLen := len(b.dirtyIndices[index]) + len(indices)
if totalIndicesLen > indicesLimit {
b.rebuildTrie[index] = true

View File

@@ -2,10 +2,15 @@ package v2
import (
"context"
"strconv"
"testing"
types "github.com/prysmaticlabs/eth2-types"
"github.com/prysmaticlabs/go-bitfield"
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
stateTypes "github.com/prysmaticlabs/prysm/beacon-chain/state/types"
"github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
eth "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/testing/assert"
@@ -57,3 +62,100 @@ func TestAppendBeyondIndicesLimit(t *testing.T) {
assert.Equal(t, true, st.rebuildTrie[validators])
assert.Equal(t, len(st.dirtyIndices[validators]), 0)
}
func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
count := uint64(100)
vals := make([]*ethpb.Validator, 0, count)
bals := make([]uint64, 0, count)
for i := uint64(1); i < count; i++ {
someRoot := [32]byte{}
someKey := [48]byte{}
copy(someRoot[:], strconv.Itoa(int(i)))
copy(someKey[:], strconv.Itoa(int(i)))
vals = append(vals, &ethpb.Validator{
PublicKey: someKey[:],
WithdrawalCredentials: someRoot[:],
EffectiveBalance: params.BeaconConfig().MaxEffectiveBalance,
Slashed: false,
ActivationEligibilityEpoch: 1,
ActivationEpoch: 1,
ExitEpoch: 1,
WithdrawableEpoch: 1,
})
bals = append(bals, params.BeaconConfig().MaxEffectiveBalance)
}
zeroHash := params.BeaconConfig().ZeroHash
mockblockRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot)
for i := 0; i < len(mockblockRoots); i++ {
mockblockRoots[i] = zeroHash[:]
}
mockstateRoots := make([][]byte, params.BeaconConfig().SlotsPerHistoricalRoot)
for i := 0; i < len(mockstateRoots); i++ {
mockstateRoots[i] = zeroHash[:]
}
mockrandaoMixes := make([][]byte, params.BeaconConfig().EpochsPerHistoricalVector)
for i := 0; i < len(mockrandaoMixes); i++ {
mockrandaoMixes[i] = zeroHash[:]
}
var pubKeys [][]byte
for i := uint64(0); i < params.BeaconConfig().SyncCommitteeSize; i++ {
pubKeys = append(pubKeys, bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength))
}
st, err := InitializeFromProto(&ethpb.BeaconStateAltair{
Slot: 1,
GenesisValidatorsRoot: make([]byte, 32),
Fork: &ethpb.Fork{
PreviousVersion: make([]byte, 4),
CurrentVersion: make([]byte, 4),
Epoch: 0,
},
LatestBlockHeader: &ethpb.BeaconBlockHeader{
ParentRoot: make([]byte, 32),
StateRoot: make([]byte, 32),
BodyRoot: make([]byte, 32),
},
CurrentEpochParticipation: []byte{},
PreviousEpochParticipation: []byte{},
Validators: vals,
Balances: bals,
Eth1Data: &eth.Eth1Data{
DepositRoot: make([]byte, 32),
BlockHash: make([]byte, 32),
},
BlockRoots: mockblockRoots,
StateRoots: mockstateRoots,
RandaoMixes: mockrandaoMixes,
JustificationBits: bitfield.NewBitvector4(),
PreviousJustifiedCheckpoint: &ethpb.Checkpoint{Root: make([]byte, 32)},
CurrentJustifiedCheckpoint: &ethpb.Checkpoint{Root: make([]byte, 32)},
FinalizedCheckpoint: &ethpb.Checkpoint{Root: make([]byte, 32)},
Slashings: make([]uint64, params.BeaconConfig().EpochsPerSlashingsVector),
CurrentSyncCommittee: &ethpb.SyncCommittee{
Pubkeys: pubKeys,
AggregatePubkey: make([]byte, 48),
},
NextSyncCommittee: &ethpb.SyncCommittee{
Pubkeys: pubKeys,
AggregatePubkey: make([]byte, 48),
},
})
assert.NoError(t, err)
_, err = st.HashTreeRoot(context.Background())
assert.NoError(t, err)
for i := 0; i < 100; i++ {
if i%2 == 0 {
assert.NoError(t, st.UpdateBalancesAtIndex(types.ValidatorIndex(i), 1000))
}
if i%3 == 0 {
assert.NoError(t, st.AppendBalance(1000))
}
}
_, err = st.HashTreeRoot(context.Background())
assert.NoError(t, err)
newRt := bytesutil.ToBytes32(st.merkleLayers[0][balances])
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.state.Balances)
assert.NoError(t, err)
assert.Equal(t, wantedRt, newRt, "state roots are unequal")
}

View File

@@ -102,6 +102,7 @@ func (b *BeaconState) SetBalances(val []uint64) error {
b.sharedFieldReferences[balances] = stateutil.NewRef(1)
b.state.Balances = val
b.rebuildTrie[balances] = true
b.markFieldAsDirty(balances)
return nil
}
@@ -128,6 +129,7 @@ func (b *BeaconState) UpdateBalancesAtIndex(idx types.ValidatorIndex, val uint64
bals[idx] = val
b.state.Balances = bals
b.markFieldAsDirty(balances)
b.addDirtyIndices(balances, []uint64{uint64(idx)})
return nil
}
@@ -219,7 +221,9 @@ func (b *BeaconState) AppendBalance(bal uint64) error {
}
b.state.Balances = append(bals, bal)
balIdx := len(b.state.Balances) - 1
b.markFieldAsDirty(balances)
b.addDirtyIndices(balances, []uint64{uint64(balIdx)})
return nil
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/prysmaticlabs/prysm/beacon-chain/state/fieldtrie"
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
"github.com/prysmaticlabs/prysm/beacon-chain/state/types"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/container/slice"
"github.com/prysmaticlabs/prysm/crypto/hash"
@@ -324,6 +325,20 @@ func (b *BeaconState) rootSelector(ctx context.Context, field types.FieldIndex)
}
return b.recomputeFieldTrie(validators, b.state.Validators)
case balances:
if features.Get().EnableBalanceTrieComputation {
if b.rebuildTrie[field] {
maxBalCap := params.BeaconConfig().ValidatorRegistryLimit
elemSize := uint64(8)
balLimit := (maxBalCap*elemSize + 31) / 32
err := b.resetFieldTrie(field, b.state.Balances, balLimit)
if err != nil {
return [32]byte{}, err
}
delete(b.rebuildTrie, field)
return b.stateFieldLeaves[field].TrieRoot()
}
return b.recomputeFieldTrie(balances, b.state.Balances)
}
return stateutil.Uint64ListRootWithRegistryLimit(b.state.Balances)
case randaoMixes:
if b.rebuildTrie[field] {

View File

@@ -6,6 +6,7 @@ import (
"testing"
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
@@ -13,6 +14,12 @@ import (
"github.com/prysmaticlabs/prysm/testing/require"
)
func TestMain(m *testing.M) {
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
defer resetCfg()
m.Run()
}
func TestValidatorMap_DistinctCopy(t *testing.T) {
count := uint64(100)
vals := make([]*ethpb.Validator, 0, count)

View File

@@ -22,6 +22,9 @@ func init() {
// Initialize the composite arrays.
fieldMap[types.Eth1DataVotes] = types.CompositeArray
fieldMap[types.Validators] = types.CompositeArray
// Initialize Compressed Arrays
fieldMap[types.Balances] = types.CompressedArray
}
// fieldMap keeps track of each field

View File

@@ -52,6 +52,7 @@ type Flags struct {
EnableHistoricalSpaceRepresentation bool // EnableHistoricalSpaceRepresentation enables the saving of registry validators in separate buckets to save space
EnableGetBlockOptimizations bool // EnableGetBlockOptimizations optimizes some elements of the GetBlock() function.
EnableBatchVerification bool // EnableBatchVerification enables batch signature verification on gossip messages.
EnableBalanceTrieComputation bool // EnableBalanceTrieComputation enables our beacon state to use balance tries for hash tree root operations.
// Logging related toggles.
DisableGRPCConnectionLogs bool // Disables logging when a new grpc client has connected.
@@ -223,6 +224,10 @@ func ConfigureBeaconChain(ctx *cli.Context) {
logEnabled(enableBatchGossipVerification)
cfg.EnableBatchVerification = true
}
if ctx.Bool(enableBalanceTrieComputation.Name) {
logEnabled(enableBalanceTrieComputation)
cfg.EnableBalanceTrieComputation = true
}
Init(cfg)
}

View File

@@ -139,6 +139,10 @@ var (
Name: "enable-batch-gossip-verification",
Usage: "This enables batch verification of signatures received over gossip.",
}
enableBalanceTrieComputation = &cli.BoolFlag{
Name: "enable-balance-trie-computation",
Usage: "This enables optimized hash tree root operations for our balance field.",
}
)
// devModeFlags holds list of flags that are set when development mode is on.
@@ -147,6 +151,7 @@ var devModeFlags = []cli.Flag{
forceOptMaxCoverAggregationStategy,
enableGetBlockOptimizations,
enableBatchGossipVerification,
enableBalanceTrieComputation,
}
// ValidatorFlags contains a list of all the feature flags that apply to the validator client.
@@ -192,6 +197,7 @@ var BeaconChainFlags = append(deprecatedFlags, []cli.Flag{
disableCorrectlyPruneCanonicalAtts,
disableActiveBalanceCache,
enableBatchGossipVerification,
enableBalanceTrieComputation,
}...)
// E2EBeaconChainFlags contains a list of the beacon chain feature flags to be tested in E2E.

View File

@@ -8,6 +8,7 @@ import (
"github.com/minio/sha256-simd"
"github.com/pkg/errors"
"github.com/prysmaticlabs/go-bitfield"
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
)
const bytesPerChunk = 32
@@ -113,6 +114,53 @@ func Pack(serializedItems [][]byte) ([][]byte, error) {
return chunks, nil
}
// PackByChunk a given byte array's final chunk with zeroes if needed.
func PackByChunk(serializedItems [][]byte) ([][bytesPerChunk]byte, error) {
emptyChunk := [bytesPerChunk]byte{}
// If there are no items, we return an empty chunk.
if len(serializedItems) == 0 {
return [][bytesPerChunk]byte{emptyChunk}, nil
} else if len(serializedItems[0]) == bytesPerChunk {
// If each item has exactly BYTES_PER_CHUNK length, we return the list of serialized items.
chunks := make([][bytesPerChunk]byte, 0, len(serializedItems))
for _, c := range serializedItems {
chunks = append(chunks, bytesutil.ToBytes32(c))
}
return chunks, nil
}
// We flatten the list in order to pack its items into byte chunks correctly.
var orderedItems []byte
for _, item := range serializedItems {
orderedItems = append(orderedItems, item...)
}
// If all our serialized item slices are length zero, we
// exit early.
if len(orderedItems) == 0 {
return [][bytesPerChunk]byte{emptyChunk}, nil
}
numItems := len(orderedItems)
var chunks [][bytesPerChunk]byte
for i := 0; i < numItems; i += bytesPerChunk {
j := i + bytesPerChunk
// We create our upper bound index of the chunk, if it is greater than numItems,
// we set it as numItems itself.
if j > numItems {
j = numItems
}
// We create chunks from the list of items based on the
// indices determined above.
// Right-pad the last chunk with zero bytes if it does not
// have length bytesPerChunk from the helper.
// The ToBytes32 helper allocates a 32-byte array, before
// copying the ordered items in. This ensures that even if
// the last chunk is != 32 in length, we will right-pad it with
// zero bytes.
chunks = append(chunks, bytesutil.ToBytes32(orderedItems[i:j]))
}
return chunks, nil
}
// MixInLength appends hash length to root
func MixInLength(root [32]byte, length []byte) [32]byte {
var hash [32]byte

View File

@@ -94,6 +94,22 @@ func TestPack(t *testing.T) {
}
}
func TestPackByChunk(t *testing.T) {
byteSlice2D := [][]byte{
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 2, 5, 2, 6, 2, 7},
{1, 1, 2, 3, 5, 8, 13, 21, 34},
}
expected := [][32]byte{{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 2, 5, 2, 6, 2, 7, 1, 1},
{2, 3, 5, 8, 13, 21, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}
result, err := ssz.PackByChunk(byteSlice2D)
require.NoError(t, err)
assert.Equal(t, len(expected), len(result))
for i, v := range expected {
assert.DeepEqual(t, v, result[i])
}
}
func TestMixInLength(t *testing.T) {
byteSlice := [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
length := []byte{1, 2, 3}

View File

@@ -5,6 +5,7 @@ go_test(
size = "small",
srcs = [
"effective_balance_updates_test.go",
"epoch_processing_test.go",
"eth1_data_reset_test.go",
"historical_roots_update_test.go",
"inactivity_updates_test.go",
@@ -21,5 +22,8 @@ go_test(
],
shard_count = 4,
tags = ["spectest"],
deps = ["//testing/spectest/shared/altair/epoch_processing:go_default_library"],
deps = [
"//config/features:go_default_library",
"//testing/spectest/shared/altair/epoch_processing:go_default_library",
],
)

View File

@@ -0,0 +1,13 @@
package epoch_processing
import (
"testing"
"github.com/prysmaticlabs/prysm/config/features"
)
func TestMain(m *testing.M) {
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
defer resetCfg()
m.Run()
}

View File

@@ -8,5 +8,8 @@ go_test(
"@consensus_spec_tests_mainnet//:test_data",
],
tags = ["spectest"],
deps = ["//testing/spectest/shared/altair/sanity:go_default_library"],
deps = [
"//config/features:go_default_library",
"//testing/spectest/shared/altair/sanity:go_default_library",
],
)

View File

@@ -3,9 +3,16 @@ package random
import (
"testing"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/testing/spectest/shared/altair/sanity"
)
func TestMain(m *testing.M) {
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
defer resetCfg()
m.Run()
}
func TestMainnet_Altair_Random(t *testing.T) {
sanity.RunBlockProcessingTest(t, "mainnet", "random/random/pyspec_tests")
}

View File

@@ -8,5 +8,8 @@ go_test(
"@consensus_spec_tests_mainnet//:test_data",
],
tags = ["spectest"],
deps = ["//testing/spectest/shared/altair/rewards:go_default_library"],
deps = [
"//config/features:go_default_library",
"//testing/spectest/shared/altair/rewards:go_default_library",
],
)

View File

@@ -3,9 +3,16 @@ package rewards
import (
"testing"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/testing/spectest/shared/altair/rewards"
)
func TestMain(m *testing.M) {
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
defer resetCfg()
m.Run()
}
func TestMainnet_Altair_Rewards(t *testing.T) {
rewards.RunPrecomputeRewardsAndPenaltiesTests(t, "mainnet")
}

View File

@@ -5,11 +5,15 @@ go_test(
size = "medium",
srcs = [
"blocks_test.go",
"sanity_test.go",
"slots_test.go",
],
data = glob(["*.yaml"]) + [
"@consensus_spec_tests_mainnet//:test_data",
],
tags = ["spectest"],
deps = ["//testing/spectest/shared/altair/sanity:go_default_library"],
deps = [
"//config/features:go_default_library",
"//testing/spectest/shared/altair/sanity:go_default_library",
],
)

View File

@@ -0,0 +1,13 @@
package sanity
import (
"testing"
"github.com/prysmaticlabs/prysm/config/features"
)
func TestMain(m *testing.M) {
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
defer resetCfg()
m.Run()
}

View File

@@ -22,6 +22,7 @@ go_test(
shard_count = 4,
tags = ["spectest"],
deps = [
"//config/features:go_default_library",
"//config/params:go_default_library",
"//testing/spectest/shared/phase0/epoch_processing:go_default_library",
],

View File

@@ -3,6 +3,7 @@ package epoch_processing
import (
"testing"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/config/params"
)
@@ -12,6 +13,8 @@ func TestMain(m *testing.M) {
c := params.BeaconConfig()
c.MinGenesisActiveValidatorCount = 16384
params.OverrideBeaconConfig(c)
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
defer resetCfg()
m.Run()
}

View File

@@ -8,5 +8,8 @@ go_test(
"@consensus_spec_tests_mainnet//:test_data",
],
tags = ["spectest"],
deps = ["//testing/spectest/shared/phase0/sanity:go_default_library"],
deps = [
"//config/features:go_default_library",
"//testing/spectest/shared/phase0/sanity:go_default_library",
],
)

View File

@@ -3,9 +3,16 @@ package random
import (
"testing"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/testing/spectest/shared/phase0/sanity"
)
func TestMain(m *testing.M) {
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
defer resetCfg()
m.Run()
}
func TestMainnet_Phase0_Random(t *testing.T) {
sanity.RunBlockProcessingTest(t, "mainnet", "random/random/pyspec_tests")
}

View File

@@ -8,5 +8,8 @@ go_test(
"@consensus_spec_tests_mainnet//:test_data",
],
tags = ["spectest"],
deps = ["//testing/spectest/shared/phase0/rewards:go_default_library"],
deps = [
"//config/features:go_default_library",
"//testing/spectest/shared/phase0/rewards:go_default_library",
],
)

View File

@@ -3,9 +3,16 @@ package rewards
import (
"testing"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/testing/spectest/shared/phase0/rewards"
)
func TestMain(m *testing.M) {
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
defer resetCfg()
m.Run()
}
func TestMainnet_Phase0_Rewards(t *testing.T) {
rewards.RunPrecomputeRewardsAndPenaltiesTests(t, "mainnet")
}

View File

@@ -5,11 +5,15 @@ go_test(
size = "medium",
srcs = [
"blocks_test.go",
"sanity_test.go",
"slots_test.go",
],
data = glob(["*.yaml"]) + [
"@consensus_spec_tests_mainnet//:test_data",
],
tags = ["spectest"],
deps = ["//testing/spectest/shared/phase0/sanity:go_default_library"],
deps = [
"//config/features:go_default_library",
"//testing/spectest/shared/phase0/sanity:go_default_library",
],
)

View File

@@ -0,0 +1,13 @@
package sanity
import (
"testing"
"github.com/prysmaticlabs/prysm/config/features"
)
func TestMain(m *testing.M) {
resetCfg := features.InitWithReset(&features.Flags{EnableBalanceTrieComputation: true})
defer resetCfg()
m.Run()
}