mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-08 23:18:15 -05:00
Integration of Vectorized Sha256 In Prysm (#10166)
* add changes * fix for vectorize * fix bug * add new bench * use new algorithms * add latest updates * save progress * hack even more * add more changes * change library * go mod * fix deps * fix dumb bug * add flag and remove redundant code * clean up better * remove those ones * clean up benches * clean up benches * cleanup * gaz * revert change * potuz's review * potuz's review * potuz's review * gaz * potuz's review * remove cyclical import * revert ide changes * potuz's review * return
This commit is contained in:
@@ -258,3 +258,21 @@ func TestBeaconState_AppendValidator_DoesntMutateCopy(t *testing.T) {
|
||||
_, ok := st1.ValidatorIndexByPubkey(bytesutil.ToBytes48(val.PublicKey))
|
||||
assert.Equal(t, false, ok, "Expected no validator index to be present in st1 for the newly inserted pubkey")
|
||||
}
|
||||
|
||||
func BenchmarkBeaconState(b *testing.B) {
|
||||
testState, _ := util.DeterministicGenesisState(b, 16000)
|
||||
pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe())
|
||||
require.NoError(b, err)
|
||||
|
||||
b.Run("Vectorized SHA256", func(b *testing.B) {
|
||||
st, err := v1.InitializeFromProtoUnsafe(pbState)
|
||||
require.NoError(b, err)
|
||||
_, err = st.HashTreeRoot(context.Background())
|
||||
assert.NoError(b, err)
|
||||
})
|
||||
|
||||
b.Run("Current SHA256", func(b *testing.B) {
|
||||
_, err := pbState.HashTreeRoot()
|
||||
require.NoError(b, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -32,10 +32,12 @@ go_library(
|
||||
],
|
||||
deps = [
|
||||
"//beacon-chain/core/transition/stateutils:go_default_library",
|
||||
"//config/features:go_default_library",
|
||||
"//config/fieldparams:go_default_library",
|
||||
"//config/params:go_default_library",
|
||||
"//container/trie:go_default_library",
|
||||
"//crypto/hash:go_default_library",
|
||||
"//crypto/hash/htr:go_default_library",
|
||||
"//encoding/bytesutil:go_default_library",
|
||||
"//encoding/ssz:go_default_library",
|
||||
"//math:go_default_library",
|
||||
@@ -51,6 +53,7 @@ go_test(
|
||||
srcs = [
|
||||
"benchmark_test.go",
|
||||
"field_root_test.go",
|
||||
"field_root_validator_test.go",
|
||||
"reference_bench_test.go",
|
||||
"state_root_test.go",
|
||||
"trie_helpers_test.go",
|
||||
@@ -58,11 +61,14 @@ go_test(
|
||||
],
|
||||
embed = [":go_default_library"],
|
||||
deps = [
|
||||
"//beacon-chain/state:go_default_library",
|
||||
"//config/features:go_default_library",
|
||||
"//config/fieldparams:go_default_library",
|
||||
"//config/params:go_default_library",
|
||||
"//crypto/hash:go_default_library",
|
||||
"//encoding/bytesutil:go_default_library",
|
||||
"//encoding/ssz:go_default_library",
|
||||
"//math:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"//runtime/interop:go_default_library",
|
||||
"//testing/assert:go_default_library",
|
||||
|
||||
@@ -5,12 +5,25 @@ import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
)
|
||||
|
||||
const (
|
||||
// number of field roots for the validator object.
|
||||
validatorFieldRoots = 8
|
||||
|
||||
// Depth of tree representation of an individual
|
||||
// validator.
|
||||
// NumOfRoots = 2 ^ (TreeDepth)
|
||||
// 8 = 2 ^ 3
|
||||
validatorTreeDepth = 3
|
||||
)
|
||||
|
||||
// ValidatorRegistryRoot computes the HashTreeRoot Merkleization of
|
||||
// a list of validator structs according to the Ethereum
|
||||
// Simple Serialize specification.
|
||||
@@ -19,14 +32,20 @@ func ValidatorRegistryRoot(vals []*ethpb.Validator) ([32]byte, error) {
|
||||
}
|
||||
|
||||
func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) {
|
||||
roots := make([][32]byte, len(validators))
|
||||
hasher := hash.CustomSHA256Hasher()
|
||||
for i := 0; i < len(validators); i++ {
|
||||
val, err := validatorRoot(hasher, validators[i])
|
||||
|
||||
var err error
|
||||
var roots [][32]byte
|
||||
if features.Get().EnableVectorizedHTR {
|
||||
roots, err = optimizedValidatorRoots(validators)
|
||||
if err != nil {
|
||||
return [32]byte{}, errors.Wrap(err, "could not compute validators merkleization")
|
||||
return [32]byte{}, err
|
||||
}
|
||||
} else {
|
||||
roots, err = validatorRoots(hasher, validators)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
roots[i] = val
|
||||
}
|
||||
|
||||
validatorsRootsRoot, err := ssz.BitwiseMerkleizeArrays(hasher, roots, uint64(len(roots)), fieldparams.ValidatorRegistryLimit)
|
||||
@@ -45,6 +64,42 @@ func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func validatorRoots(hasher func([]byte) [32]byte, validators []*ethpb.Validator) ([][32]byte, error) {
|
||||
roots := make([][32]byte, len(validators))
|
||||
for i := 0; i < len(validators); i++ {
|
||||
val, err := validatorRoot(hasher, validators[i])
|
||||
if err != nil {
|
||||
return [][32]byte{}, errors.Wrap(err, "could not compute validators merkleization")
|
||||
}
|
||||
roots[i] = val
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
|
||||
func optimizedValidatorRoots(validators []*ethpb.Validator) ([][32]byte, error) {
|
||||
roots := make([][32]byte, 0, len(validators)*validatorFieldRoots)
|
||||
hasher := hash.CustomSHA256Hasher()
|
||||
for i := 0; i < len(validators); i++ {
|
||||
fRoots, err := ValidatorFieldRoots(hasher, validators[i])
|
||||
if err != nil {
|
||||
return [][32]byte{}, errors.Wrap(err, "could not compute validators merkleization")
|
||||
}
|
||||
roots = append(roots, fRoots...)
|
||||
}
|
||||
|
||||
// A validator's tree can represented with a depth of 3. As log2(8) = 3
|
||||
// Using this property we can lay out all the individual fields of a
|
||||
// validator and hash them in single level using our vectorized routine.
|
||||
for i := 0; i < validatorTreeDepth; i++ {
|
||||
// Overwrite input lists as we are hashing by level
|
||||
// and only need the highest level to proceed.
|
||||
outputLen := len(roots) / 2
|
||||
htr.VectorizedSha256(roots, roots)
|
||||
roots = roots[:outputLen]
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
|
||||
func validatorRoot(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) {
|
||||
if validator == nil {
|
||||
return [32]byte{}, errors.New("nil validator")
|
||||
|
||||
32
beacon-chain/state/stateutil/field_root_validator_test.go
Normal file
32
beacon-chain/state/stateutil/field_root_validator_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package stateutil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
mathutil "github.com/prysmaticlabs/prysm/math"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/testing/assert"
|
||||
)
|
||||
|
||||
func TestValidatorConstants(t *testing.T) {
|
||||
v := ðpb.Validator{}
|
||||
refV := reflect.ValueOf(v).Elem()
|
||||
numFields := refV.NumField()
|
||||
numOfValFields := 0
|
||||
|
||||
for i := 0; i < numFields; i++ {
|
||||
if strings.Contains(refV.Type().Field(i).Name, "state") ||
|
||||
strings.Contains(refV.Type().Field(i).Name, "sizeCache") ||
|
||||
strings.Contains(refV.Type().Field(i).Name, "unknownFields") {
|
||||
continue
|
||||
}
|
||||
numOfValFields++
|
||||
}
|
||||
assert.Equal(t, validatorFieldRoots, numOfValFields)
|
||||
assert.Equal(t, uint64(validatorFieldRoots), mathutil.PowerOf2(validatorTreeDepth))
|
||||
|
||||
_, err := ValidatorRegistryRoot([]*ethpb.Validator{v})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/container/trie"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
"github.com/prysmaticlabs/prysm/math"
|
||||
)
|
||||
@@ -61,25 +63,43 @@ func ReturnTrieLayerVariable(elements [][32]byte, length uint64) [][]*[32]byte {
|
||||
layers[0] = transformedLeaves
|
||||
buffer := bytes.NewBuffer([]byte{})
|
||||
buffer.Grow(64)
|
||||
|
||||
for i := 0; i < int(depth); i++ {
|
||||
oddNodeLength := len(layers[i])%2 == 1
|
||||
if oddNodeLength {
|
||||
zerohash := trie.ZeroHashes[i]
|
||||
layers[i] = append(layers[i], &zerohash)
|
||||
layerLen := len(layers[i])
|
||||
oddNodeLength := layerLen%2 == 1
|
||||
if features.Get().EnableVectorizedHTR {
|
||||
if oddNodeLength {
|
||||
zerohash := trie.ZeroHashes[i]
|
||||
elements = append(elements, zerohash)
|
||||
layerLen++
|
||||
}
|
||||
|
||||
layers[i+1] = make([]*[32]byte, layerLen/2)
|
||||
newElems := make([][32]byte, layerLen/2)
|
||||
htr.VectorizedSha256(elements, newElems)
|
||||
elements = newElems
|
||||
for j := range elements {
|
||||
layers[i+1][j] = &elements[j]
|
||||
}
|
||||
} else {
|
||||
if oddNodeLength {
|
||||
zerohash := trie.ZeroHashes[i]
|
||||
layers[i] = append(layers[i], &zerohash)
|
||||
}
|
||||
updatedValues := make([]*[32]byte, 0, len(layers[i])/2)
|
||||
for j := 0; j < len(layers[i]); j += 2 {
|
||||
buffer.Write(layers[i][j][:])
|
||||
buffer.Write(layers[i][j+1][:])
|
||||
concat := hasher(buffer.Bytes())
|
||||
updatedValues = append(updatedValues, &concat)
|
||||
buffer.Reset()
|
||||
}
|
||||
// remove zerohash node from tree
|
||||
if oddNodeLength {
|
||||
layers[i] = layers[i][:len(layers[i])-1]
|
||||
}
|
||||
layers[i+1] = updatedValues
|
||||
}
|
||||
updatedValues := make([]*[32]byte, 0, len(layers[i])/2)
|
||||
for j := 0; j < len(layers[i]); j += 2 {
|
||||
buffer.Write(layers[i][j][:])
|
||||
buffer.Write(layers[i][j+1][:])
|
||||
concat := hasher(buffer.Bytes())
|
||||
updatedValues = append(updatedValues, &concat)
|
||||
buffer.Reset()
|
||||
}
|
||||
// remove zerohash node from tree
|
||||
if oddNodeLength {
|
||||
layers[i] = layers[i][:len(layers[i])-1]
|
||||
}
|
||||
layers[i+1] = updatedValues
|
||||
}
|
||||
return layers
|
||||
}
|
||||
@@ -277,18 +297,24 @@ func MerkleizeTrieLeaves(layers [][][32]byte, hashLayer [][32]byte,
|
||||
chunkBuffer := bytes.NewBuffer([]byte{})
|
||||
chunkBuffer.Grow(64)
|
||||
for len(hashLayer) > 1 && i < len(layers) {
|
||||
layer := make([][32]byte, len(hashLayer)/2)
|
||||
if !math.IsPowerOf2(uint64(len(hashLayer))) {
|
||||
return nil, nil, errors.Errorf("hash layer is a non power of 2: %d", len(hashLayer))
|
||||
}
|
||||
for j := 0; j < len(hashLayer); j += 2 {
|
||||
chunkBuffer.Write(hashLayer[j][:])
|
||||
chunkBuffer.Write(hashLayer[j+1][:])
|
||||
hashedChunk := hasher(chunkBuffer.Bytes())
|
||||
layer[j/2] = hashedChunk
|
||||
chunkBuffer.Reset()
|
||||
if features.Get().EnableVectorizedHTR {
|
||||
newLayer := make([][32]byte, len(hashLayer)/2)
|
||||
htr.VectorizedSha256(hashLayer, newLayer)
|
||||
hashLayer = newLayer
|
||||
} else {
|
||||
layer := make([][32]byte, len(hashLayer)/2)
|
||||
for j := 0; j < len(hashLayer); j += 2 {
|
||||
chunkBuffer.Write(hashLayer[j][:])
|
||||
chunkBuffer.Write(hashLayer[j+1][:])
|
||||
hashedChunk := hasher(chunkBuffer.Bytes())
|
||||
layer[j/2] = hashedChunk
|
||||
chunkBuffer.Reset()
|
||||
}
|
||||
hashLayer = layer
|
||||
}
|
||||
hashLayer = layer
|
||||
layers[i] = hashLayer
|
||||
i++
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"testing"
|
||||
|
||||
types "github.com/prysmaticlabs/eth2-types"
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state"
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/config/params"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
@@ -18,15 +20,56 @@ func TestReturnTrieLayer_OK(t *testing.T) {
|
||||
newState, _ := util.DeterministicGenesisState(t, 32)
|
||||
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
|
||||
require.NoError(t, err)
|
||||
blockRts := newState.BlockRoots()
|
||||
roots := make([][32]byte, 0, len(blockRts))
|
||||
for _, rt := range blockRts {
|
||||
roots = append(roots, bytesutil.ToBytes32(rt))
|
||||
}
|
||||
roots := retrieveBlockRoots(newState)
|
||||
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
|
||||
assert.NoError(t, err)
|
||||
newRoot := *layers[len(layers)-1][0]
|
||||
assert.Equal(t, root, newRoot)
|
||||
|
||||
flags := &features.Flags{}
|
||||
flags.EnableVectorizedHTR = true
|
||||
reset := features.InitWithReset(flags)
|
||||
defer reset()
|
||||
|
||||
layers, err = stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
|
||||
assert.NoError(t, err)
|
||||
lastRoot := *layers[len(layers)-1][0]
|
||||
assert.Equal(t, root, lastRoot)
|
||||
}
|
||||
|
||||
func BenchmarkReturnTrieLayer_NormalAlgorithm(b *testing.B) {
|
||||
newState, _ := util.DeterministicGenesisState(b, 32)
|
||||
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
|
||||
require.NoError(b, err)
|
||||
roots := retrieveBlockRoots(newState)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
|
||||
assert.NoError(b, err)
|
||||
newRoot := *layers[len(layers)-1][0]
|
||||
assert.Equal(b, root, newRoot)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReturnTrieLayer_VectorizedAlgorithm(b *testing.B) {
|
||||
flags := &features.Flags{}
|
||||
flags.EnableVectorizedHTR = true
|
||||
reset := features.InitWithReset(flags)
|
||||
defer reset()
|
||||
|
||||
newState, _ := util.DeterministicGenesisState(b, 32)
|
||||
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
|
||||
require.NoError(b, err)
|
||||
roots := retrieveBlockRoots(newState)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
|
||||
assert.NoError(b, err)
|
||||
newRoot := *layers[len(layers)-1][0]
|
||||
assert.Equal(b, root, newRoot)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturnTrieLayerVariable_OK(t *testing.T) {
|
||||
@@ -46,15 +89,73 @@ func TestReturnTrieLayerVariable_OK(t *testing.T) {
|
||||
newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators)))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, root, newRoot)
|
||||
|
||||
flags := &features.Flags{}
|
||||
flags.EnableVectorizedHTR = true
|
||||
reset := features.InitWithReset(flags)
|
||||
defer reset()
|
||||
|
||||
layers = stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit)
|
||||
lastRoot := *layers[len(layers)-1][0]
|
||||
lastRoot, err = stateutil.AddInMixin(lastRoot, uint64(len(validators)))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, root, lastRoot)
|
||||
|
||||
}
|
||||
|
||||
func BenchmarkReturnTrieLayerVariable_NormalAlgorithm(b *testing.B) {
|
||||
newState, _ := util.DeterministicGenesisState(b, 16000)
|
||||
root, err := stateutil.ValidatorRegistryRoot(newState.Validators())
|
||||
require.NoError(b, err)
|
||||
hasher := hash.CustomSHA256Hasher()
|
||||
validators := newState.Validators()
|
||||
roots := make([][32]byte, 0, len(validators))
|
||||
for _, val := range validators {
|
||||
rt, err := stateutil.ValidatorRootWithHasher(hasher, val)
|
||||
require.NoError(b, err)
|
||||
roots = append(roots, rt)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
layers := stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit)
|
||||
newRoot := *layers[len(layers)-1][0]
|
||||
newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators)))
|
||||
require.NoError(b, err)
|
||||
assert.Equal(b, root, newRoot)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReturnTrieLayerVariable_VectorizedAlgorithm(b *testing.B) {
|
||||
flags := &features.Flags{}
|
||||
flags.EnableVectorizedHTR = true
|
||||
reset := features.InitWithReset(flags)
|
||||
defer reset()
|
||||
|
||||
newState, _ := util.DeterministicGenesisState(b, 16000)
|
||||
root, err := stateutil.ValidatorRegistryRoot(newState.Validators())
|
||||
require.NoError(b, err)
|
||||
hasher := hash.CustomSHA256Hasher()
|
||||
validators := newState.Validators()
|
||||
roots := make([][32]byte, 0, len(validators))
|
||||
for _, val := range validators {
|
||||
rt, err := stateutil.ValidatorRootWithHasher(hasher, val)
|
||||
require.NoError(b, err)
|
||||
roots = append(roots, rt)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
layers := stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit)
|
||||
newRoot := *layers[len(layers)-1][0]
|
||||
newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators)))
|
||||
require.NoError(b, err)
|
||||
assert.Equal(b, root, newRoot)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecomputeFromLayer_FixedSizedArray(t *testing.T) {
|
||||
newState, _ := util.DeterministicGenesisState(t, 32)
|
||||
blockRts := newState.BlockRoots()
|
||||
roots := make([][32]byte, 0, len(blockRts))
|
||||
for _, rt := range blockRts {
|
||||
roots = append(roots, bytesutil.ToBytes32(rt))
|
||||
}
|
||||
roots := retrieveBlockRoots(newState)
|
||||
|
||||
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -120,3 +221,12 @@ func TestMerkleizeTrieLeaves_BadHashLayer(t *testing.T) {
|
||||
})
|
||||
assert.ErrorContains(t, "hash layer is a non power of 2", err)
|
||||
}
|
||||
|
||||
func retrieveBlockRoots(b state.BeaconState) [][32]byte {
|
||||
blockRts := b.BlockRoots()
|
||||
roots := make([][32]byte, 0, len(blockRts))
|
||||
for _, rt := range blockRts {
|
||||
roots = append(roots, bytesutil.ToBytes32(rt))
|
||||
}
|
||||
return roots
|
||||
}
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
@@ -14,6 +16,16 @@ import (
|
||||
// ValidatorRootWithHasher describes a method from which the hash tree root
|
||||
// of a validator is returned.
|
||||
func ValidatorRootWithHasher(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) {
|
||||
fieldRoots, err := ValidatorFieldRoots(hasher, validator)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
return ssz.BitwiseMerkleizeArrays(hasher, fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots)))
|
||||
}
|
||||
|
||||
// ValidatorFieldRoots describes a method from which the hash tree root
|
||||
// of a validator is returned.
|
||||
func ValidatorFieldRoots(hasher ssz.HashFn, validator *ethpb.Validator) ([][32]byte, error) {
|
||||
var fieldRoots [][32]byte
|
||||
if validator != nil {
|
||||
pubkey := bytesutil.ToBytes48(validator.PublicKey)
|
||||
@@ -40,18 +52,25 @@ func ValidatorRootWithHasher(hasher ssz.HashFn, validator *ethpb.Validator) ([32
|
||||
binary.LittleEndian.PutUint64(withdrawalBuf[:8], uint64(validator.WithdrawableEpoch))
|
||||
|
||||
// Public key.
|
||||
pubKeyChunks, err := ssz.Pack([][]byte{pubkey[:]})
|
||||
pubKeyChunks, err := ssz.PackByChunk([][]byte{pubkey[:]})
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
return [][32]byte{}, err
|
||||
}
|
||||
pubKeyRoot, err := ssz.BitwiseMerkleize(hasher, pubKeyChunks, uint64(len(pubKeyChunks)), uint64(len(pubKeyChunks)))
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
var pubKeyRoot [32]byte
|
||||
if features.Get().EnableVectorizedHTR {
|
||||
outputChunk := make([][32]byte, 1)
|
||||
htr.VectorizedSha256(pubKeyChunks, outputChunk)
|
||||
pubKeyRoot = outputChunk[0]
|
||||
} else {
|
||||
pubKeyRoot, err = ssz.BitwiseMerkleizeArrays(hasher, pubKeyChunks, uint64(len(pubKeyChunks)), uint64(len(pubKeyChunks)))
|
||||
if err != nil {
|
||||
return [][32]byte{}, err
|
||||
}
|
||||
}
|
||||
fieldRoots = [][32]byte{pubKeyRoot, withdrawCreds, effectiveBalanceBuf, slashBuf, activationEligibilityBuf,
|
||||
activationBuf, exitBuf, withdrawalBuf}
|
||||
}
|
||||
return ssz.BitwiseMerkleizeArrays(hasher, fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots)))
|
||||
return fieldRoots, nil
|
||||
}
|
||||
|
||||
// Uint64ListRootWithRegistryLimit computes the HashTreeRoot Merkleization of
|
||||
|
||||
@@ -177,6 +177,24 @@ func TestBeaconState_HashTreeRoot(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBeaconState(b *testing.B) {
|
||||
testState, _ := util.DeterministicGenesisState(b, 16000)
|
||||
pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe())
|
||||
require.NoError(b, err)
|
||||
|
||||
b.Run("Vectorized SHA256", func(b *testing.B) {
|
||||
st, err := v1.InitializeFromProtoUnsafe(pbState)
|
||||
require.NoError(b, err)
|
||||
_, err = st.HashTreeRoot(context.Background())
|
||||
assert.NoError(b, err)
|
||||
})
|
||||
|
||||
b.Run("Current SHA256", func(b *testing.B) {
|
||||
_, err := pbState.HashTreeRoot()
|
||||
require.NoError(b, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBeaconState_HashTreeRoot_FieldTrie(t *testing.T) {
|
||||
testState, _ := util.DeterministicGenesisState(t, 64)
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ go_library(
|
||||
"//crypto/bls:go_default_library",
|
||||
"//crypto/rand:go_default_library",
|
||||
"//encoding/bytesutil:go_default_library",
|
||||
"//encoding/ssz:go_default_library",
|
||||
"//encoding/ssz/equality:go_default_library",
|
||||
"//monitoring/tracing:go_default_library",
|
||||
"//network/forks:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
@@ -195,7 +195,7 @@ go_test(
|
||||
"//crypto/bls:go_default_library",
|
||||
"//crypto/rand:go_default_library",
|
||||
"//encoding/bytesutil:go_default_library",
|
||||
"//encoding/ssz:go_default_library",
|
||||
"//encoding/ssz/equality:go_default_library",
|
||||
"//network/forks:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"//proto/prysm/v1alpha1/attestation:go_default_library",
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/prysmaticlabs/prysm/config/params"
|
||||
"github.com/prysmaticlabs/prysm/crypto/rand"
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
|
||||
"github.com/prysmaticlabs/prysm/monitoring/tracing"
|
||||
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/block"
|
||||
"github.com/prysmaticlabs/prysm/time/slots"
|
||||
@@ -328,7 +328,7 @@ func (s *Service) deleteBlockFromPendingQueue(slot types.Slot, b block.SignedBea
|
||||
|
||||
newBlks := make([]block.SignedBeaconBlock, 0, len(blks))
|
||||
for _, blk := range blks {
|
||||
if ssz.DeepEqual(blk.Proto(), b.Proto()) {
|
||||
if equality.DeepEqual(blk.Proto(), b.Proto()) {
|
||||
continue
|
||||
}
|
||||
newBlks = append(newBlks, blk)
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/p2p"
|
||||
p2ptest "github.com/prysmaticlabs/prysm/beacon-chain/p2p/testing"
|
||||
"github.com/prysmaticlabs/prysm/config/params"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
|
||||
pb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/metadata"
|
||||
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper"
|
||||
@@ -125,7 +125,7 @@ func TestMetadataRPCHandler_SendsMetadata(t *testing.T) {
|
||||
metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID())
|
||||
assert.NoError(t, err)
|
||||
|
||||
if !ssz.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
|
||||
if !equality.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
|
||||
t.Fatalf("MetadataV0 unequal, received %v but wanted %v", metadata, p2.LocalMetadata)
|
||||
}
|
||||
|
||||
@@ -213,7 +213,7 @@ func TestMetadataRPCHandler_SendsMetadataAltair(t *testing.T) {
|
||||
metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID())
|
||||
assert.NoError(t, err)
|
||||
|
||||
if !ssz.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
|
||||
if !equality.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
|
||||
t.Fatalf("MetadataV1 unequal, received %v but wanted %v", metadata, p2.LocalMetadata)
|
||||
}
|
||||
|
||||
|
||||
@@ -74,7 +74,8 @@ type Flags struct {
|
||||
CorrectlyInsertOrphanedAtts bool
|
||||
CorrectlyPruneCanonicalAtts bool
|
||||
|
||||
EnableNativeState bool // EnableNativeState defines whether the beacon state will be represented as a pure Go struct or a Go struct that wraps a proto struct.
|
||||
EnableNativeState bool // EnableNativeState defines whether the beacon state will be represented as a pure Go struct or a Go struct that wraps a proto struct.
|
||||
EnableVectorizedHTR bool // EnableVectorizedHTR specifies whether the beacon state will use the optimized sha256 routines.
|
||||
|
||||
// KeystoreImportDebounceInterval specifies the time duration the validator waits to reload new keys if they have
|
||||
// changed on disk. This feature is for advanced use cases only.
|
||||
@@ -222,6 +223,10 @@ func ConfigureBeaconChain(ctx *cli.Context) {
|
||||
logEnabled(enableNativeState)
|
||||
cfg.EnableNativeState = true
|
||||
}
|
||||
if ctx.Bool(enableVecHTR.Name) {
|
||||
logEnabled(enableVecHTR)
|
||||
cfg.EnableVectorizedHTR = true
|
||||
}
|
||||
Init(cfg)
|
||||
}
|
||||
|
||||
|
||||
@@ -134,11 +134,16 @@ var (
|
||||
Name: "enable-native-state",
|
||||
Usage: "Enables representing the beacon state as a pure Go struct.",
|
||||
}
|
||||
enableVecHTR = &cli.BoolFlag{
|
||||
Name: "enable-vectorized-htr",
|
||||
Usage: "Enables new go sha256 library which utilizes optimized routines for merkle trees",
|
||||
}
|
||||
)
|
||||
|
||||
// devModeFlags holds list of flags that are set when development mode is on.
|
||||
var devModeFlags = []cli.Flag{
|
||||
enablePeerScorer,
|
||||
enableVecHTR,
|
||||
}
|
||||
|
||||
// ValidatorFlags contains a list of all the feature flags that apply to the validator client.
|
||||
@@ -183,6 +188,7 @@ var BeaconChainFlags = append(deprecatedFlags, []cli.Flag{
|
||||
disableBatchGossipVerification,
|
||||
disableBalanceTrieComputation,
|
||||
enableNativeState,
|
||||
enableVecHTR,
|
||||
}...)
|
||||
|
||||
// E2EBeaconChainFlags contains a list of the beacon chain feature flags to be tested in E2E.
|
||||
|
||||
9
crypto/hash/htr/BUILD.bazel
Normal file
9
crypto/hash/htr/BUILD.bazel
Normal file
@@ -0,0 +1,9 @@
|
||||
load("@prysm//tools/go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = ["hashtree.go"],
|
||||
importpath = "github.com/prysmaticlabs/prysm/crypto/hash/htr",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["@com_github_prysmaticlabs_gohashtree//:go_default_library"],
|
||||
)
|
||||
17
crypto/hash/htr/hashtree.go
Normal file
17
crypto/hash/htr/hashtree.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package htr
|
||||
|
||||
import (
|
||||
"github.com/prysmaticlabs/gohashtree"
|
||||
)
|
||||
|
||||
// VectorizedSha256 takes a list of roots and hashes them using CPU
|
||||
// specific vector instructions. Depending on host machine's specific
|
||||
// hardware configuration, using this routine can lead to a significant
|
||||
// performance improvement compared to the default method of hashing
|
||||
// lists.
|
||||
func VectorizedSha256(inputList [][32]byte, outputList [][32]byte) {
|
||||
err := gohashtree.Hash(outputList, inputList)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
7
deps.bzl
7
deps.bzl
@@ -2982,6 +2982,13 @@ def prysm_deps():
|
||||
sum = "h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw=",
|
||||
version = "v0.0.0-20210809151128-385d8c5e3fb7",
|
||||
)
|
||||
go_repository(
|
||||
name = "com_github_prysmaticlabs_gohashtree",
|
||||
importpath = "github.com/prysmaticlabs/gohashtree",
|
||||
sum = "h1:qVXMFNUGWbPpP7986YPsmL84A54oY7b/N+SOIBS3i5w=",
|
||||
version = "v0.0.0-20220208111633-0606f58df32f",
|
||||
)
|
||||
|
||||
go_repository(
|
||||
name = "com_github_prysmaticlabs_prombbolt",
|
||||
importpath = "github.com/prysmaticlabs/prombbolt",
|
||||
|
||||
@@ -3,7 +3,6 @@ load("@prysm//tools/go:def.bzl", "go_library", "go_test")
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = [
|
||||
"deep_equal.go",
|
||||
"hashers.go",
|
||||
"helpers.go",
|
||||
"htrutils.go",
|
||||
@@ -12,16 +11,16 @@ go_library(
|
||||
importpath = "github.com/prysmaticlabs/prysm/encoding/ssz",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//config/features:go_default_library",
|
||||
"//config/fieldparams:go_default_library",
|
||||
"//container/trie:go_default_library",
|
||||
"//crypto/hash:go_default_library",
|
||||
"//crypto/hash/htr:go_default_library",
|
||||
"//encoding/bytesutil:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"@com_github_minio_sha256_simd//:go_default_library",
|
||||
"@com_github_pkg_errors//:go_default_library",
|
||||
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
|
||||
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
|
||||
"@org_golang_google_protobuf//proto:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -29,7 +28,6 @@ go_test(
|
||||
name = "go_default_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"deep_equal_test.go",
|
||||
"hashers_test.go",
|
||||
"helpers_test.go",
|
||||
"htrutils_test.go",
|
||||
|
||||
22
encoding/ssz/equality/BUILD.bazel
Normal file
22
encoding/ssz/equality/BUILD.bazel
Normal file
@@ -0,0 +1,22 @@
|
||||
load("@prysm//tools/go:def.bzl", "go_library", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = ["deep_equal.go"],
|
||||
importpath = "github.com/prysmaticlabs/prysm/encoding/ssz/equality",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
|
||||
"@org_golang_google_protobuf//proto:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "go_default_test",
|
||||
srcs = ["deep_equal_test.go"],
|
||||
deps = [
|
||||
":go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"//testing/assert:go_default_library",
|
||||
],
|
||||
)
|
||||
@@ -1,4 +1,4 @@
|
||||
package ssz
|
||||
package equality
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
@@ -1,34 +1,34 @@
|
||||
package ssz_test
|
||||
package equality_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/testing/assert"
|
||||
)
|
||||
|
||||
func TestDeepEqualBasicTypes(t *testing.T) {
|
||||
assert.Equal(t, true, ssz.DeepEqual(true, true))
|
||||
assert.Equal(t, false, ssz.DeepEqual(true, false))
|
||||
assert.Equal(t, true, equality.DeepEqual(true, true))
|
||||
assert.Equal(t, false, equality.DeepEqual(true, false))
|
||||
|
||||
assert.Equal(t, true, ssz.DeepEqual(byte(222), byte(222)))
|
||||
assert.Equal(t, false, ssz.DeepEqual(byte(222), byte(111)))
|
||||
assert.Equal(t, true, equality.DeepEqual(byte(222), byte(222)))
|
||||
assert.Equal(t, false, equality.DeepEqual(byte(222), byte(111)))
|
||||
|
||||
assert.Equal(t, true, ssz.DeepEqual(uint64(1234567890), uint64(1234567890)))
|
||||
assert.Equal(t, false, ssz.DeepEqual(uint64(1234567890), uint64(987653210)))
|
||||
assert.Equal(t, true, equality.DeepEqual(uint64(1234567890), uint64(1234567890)))
|
||||
assert.Equal(t, false, equality.DeepEqual(uint64(1234567890), uint64(987653210)))
|
||||
|
||||
assert.Equal(t, true, ssz.DeepEqual("hello", "hello"))
|
||||
assert.Equal(t, false, ssz.DeepEqual("hello", "world"))
|
||||
assert.Equal(t, true, equality.DeepEqual("hello", "hello"))
|
||||
assert.Equal(t, false, equality.DeepEqual("hello", "world"))
|
||||
|
||||
assert.Equal(t, true, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 3}))
|
||||
assert.Equal(t, false, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 4}))
|
||||
assert.Equal(t, true, equality.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 3}))
|
||||
assert.Equal(t, false, equality.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 4}))
|
||||
|
||||
var nilSlice1, nilSlice2 []byte
|
||||
assert.Equal(t, true, ssz.DeepEqual(nilSlice1, nilSlice2))
|
||||
assert.Equal(t, true, ssz.DeepEqual(nilSlice1, []byte{}))
|
||||
assert.Equal(t, true, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 3}))
|
||||
assert.Equal(t, false, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 4}))
|
||||
assert.Equal(t, true, equality.DeepEqual(nilSlice1, nilSlice2))
|
||||
assert.Equal(t, true, equality.DeepEqual(nilSlice1, []byte{}))
|
||||
assert.Equal(t, true, equality.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 3}))
|
||||
assert.Equal(t, false, equality.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 4}))
|
||||
}
|
||||
|
||||
func TestDeepEqualStructs(t *testing.T) {
|
||||
@@ -39,8 +39,8 @@ func TestDeepEqualStructs(t *testing.T) {
|
||||
store1 := Store{uint64(1234), nil}
|
||||
store2 := Store{uint64(1234), []byte{}}
|
||||
store3 := Store{uint64(4321), []byte{}}
|
||||
assert.Equal(t, true, ssz.DeepEqual(store1, store2))
|
||||
assert.Equal(t, false, ssz.DeepEqual(store1, store3))
|
||||
assert.Equal(t, true, equality.DeepEqual(store1, store2))
|
||||
assert.Equal(t, false, equality.DeepEqual(store1, store3))
|
||||
}
|
||||
|
||||
func TestDeepEqualStructs_Unexported(t *testing.T) {
|
||||
@@ -53,14 +53,14 @@ func TestDeepEqualStructs_Unexported(t *testing.T) {
|
||||
store2 := Store{uint64(1234), []byte{}, "hi there"}
|
||||
store3 := Store{uint64(4321), []byte{}, "wow"}
|
||||
store4 := Store{uint64(4321), []byte{}, "bow wow"}
|
||||
assert.Equal(t, true, ssz.DeepEqual(store1, store2))
|
||||
assert.Equal(t, false, ssz.DeepEqual(store1, store3))
|
||||
assert.Equal(t, false, ssz.DeepEqual(store3, store4))
|
||||
assert.Equal(t, true, equality.DeepEqual(store1, store2))
|
||||
assert.Equal(t, false, equality.DeepEqual(store1, store3))
|
||||
assert.Equal(t, false, equality.DeepEqual(store3, store4))
|
||||
}
|
||||
|
||||
func TestDeepEqualProto(t *testing.T) {
|
||||
var fork1, fork2 *ethpb.Fork
|
||||
assert.Equal(t, true, ssz.DeepEqual(fork1, fork2))
|
||||
assert.Equal(t, true, equality.DeepEqual(fork1, fork2))
|
||||
|
||||
fork1 = ðpb.Fork{
|
||||
PreviousVersion: []byte{123},
|
||||
@@ -72,8 +72,8 @@ func TestDeepEqualProto(t *testing.T) {
|
||||
CurrentVersion: []byte{125},
|
||||
Epoch: 1234567890,
|
||||
}
|
||||
assert.Equal(t, true, ssz.DeepEqual(fork1, fork1))
|
||||
assert.Equal(t, false, ssz.DeepEqual(fork1, fork2))
|
||||
assert.Equal(t, true, equality.DeepEqual(fork1, fork1))
|
||||
assert.Equal(t, false, equality.DeepEqual(fork1, fork2))
|
||||
|
||||
checkpoint1 := ðpb.Checkpoint{
|
||||
Epoch: 1234567890,
|
||||
@@ -83,7 +83,7 @@ func TestDeepEqualProto(t *testing.T) {
|
||||
Epoch: 1234567890,
|
||||
Root: nil,
|
||||
}
|
||||
assert.Equal(t, true, ssz.DeepEqual(checkpoint1, checkpoint2))
|
||||
assert.Equal(t, true, equality.DeepEqual(checkpoint1, checkpoint2))
|
||||
}
|
||||
|
||||
func Test_IsProto(t *testing.T) {
|
||||
@@ -125,7 +125,7 @@ func Test_IsProto(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ssz.IsProto(tt.item); got != tt.want {
|
||||
if got := equality.IsProto(tt.item); got != tt.want {
|
||||
t.Errorf("isProtoSlice() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/minio/sha256-simd"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prysmaticlabs/go-bitfield"
|
||||
"github.com/prysmaticlabs/prysm/config/features"
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
)
|
||||
|
||||
@@ -50,6 +51,9 @@ func BitwiseMerkleize(hasher HashFn, chunks [][]byte, count, limit uint64) ([32]
|
||||
if count > limit {
|
||||
return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
|
||||
}
|
||||
if features.Get().EnableVectorizedHTR {
|
||||
return MerkleizeList(chunks, limit), nil
|
||||
}
|
||||
hashFn := NewHasherFunc(hasher)
|
||||
leafIndexer := func(i uint64) []byte {
|
||||
return chunks[i]
|
||||
@@ -62,6 +66,9 @@ func BitwiseMerkleizeArrays(hasher HashFn, chunks [][32]byte, count, limit uint6
|
||||
if count > limit {
|
||||
return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
|
||||
}
|
||||
if features.Get().EnableVectorizedHTR {
|
||||
return MerkleizeVector(chunks, limit), nil
|
||||
}
|
||||
hashFn := NewHasherFunc(hasher)
|
||||
leafIndexer := func(i uint64) []byte {
|
||||
return chunks[i][:]
|
||||
|
||||
@@ -2,6 +2,7 @@ package ssz
|
||||
|
||||
import (
|
||||
"github.com/prysmaticlabs/prysm/container/trie"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
|
||||
)
|
||||
|
||||
// Merkleize.go is mostly a directly copy of the same filename from
|
||||
@@ -196,3 +197,40 @@ func ConstructProof(hasher Hasher, count, limit uint64, leaf func(i uint64) []by
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// MerkleizeVector uses our optimized routine to hash a list of 32-byte
|
||||
// elements.
|
||||
func MerkleizeVector(elements [][32]byte, length uint64) [32]byte {
|
||||
depth := Depth(length)
|
||||
// Return zerohash at depth
|
||||
if len(elements) == 0 {
|
||||
return trie.ZeroHashes[depth]
|
||||
}
|
||||
for i := 0; i < int(depth); i++ {
|
||||
layerLen := len(elements)
|
||||
oddNodeLength := layerLen%2 == 1
|
||||
if oddNodeLength {
|
||||
zerohash := trie.ZeroHashes[i]
|
||||
elements = append(elements, zerohash)
|
||||
}
|
||||
outputLen := len(elements) / 2
|
||||
htr.VectorizedSha256(elements, elements)
|
||||
elements = elements[:outputLen]
|
||||
}
|
||||
return elements[0]
|
||||
}
|
||||
|
||||
// MerkleizeList uses our optimized routine to hash a 2d-list of
|
||||
// elements.
|
||||
func MerkleizeList(elements [][]byte, length uint64) [32]byte {
|
||||
depth := Depth(length)
|
||||
// Return zerohash at depth
|
||||
if len(elements) == 0 {
|
||||
return trie.ZeroHashes[depth]
|
||||
}
|
||||
newElems := make([][32]byte, len(elements))
|
||||
for i := range elements {
|
||||
copy(newElems[i][:], elements[i])
|
||||
}
|
||||
return MerkleizeVector(newElems, length)
|
||||
}
|
||||
|
||||
1
go.mod
1
go.mod
@@ -253,6 +253,7 @@ require (
|
||||
github.com/holiman/uint256 v1.2.0
|
||||
github.com/peterh/liner v1.2.0 // indirect
|
||||
github.com/prometheus/tsdb v0.10.0 // indirect
|
||||
github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f
|
||||
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 // indirect
|
||||
google.golang.org/api v0.34.0 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -1123,6 +1123,8 @@ github.com/prysmaticlabs/fastssz v0.0.0-20220110145812-fafb696cae88/go.mod h1:AS
|
||||
github.com/prysmaticlabs/go-bitfield v0.0.0-20210108222456-8e92c3709aa0/go.mod h1:hCwmef+4qXWjv0jLDbQdWnL0Ol7cS7/lCSS26WR+u6s=
|
||||
github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7 h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw=
|
||||
github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7/go.mod h1:wmuf/mdK4VMD+jA9ThwcUKjg3a2XWM9cVfFYjDyY4j4=
|
||||
github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f h1:qVXMFNUGWbPpP7986YPsmL84A54oY7b/N+SOIBS3i5w=
|
||||
github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f/go.mod h1:4pWaT30XoEx1j8KNJf3TV+E3mQkaufn7mf+jRNb/Fuk=
|
||||
github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1 h1:xcu59yYL6AWWTl6jtejBfE0y8uF35fArCBeZjRlvJss=
|
||||
github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1/go.mod h1:IOyTYjcIO0rkmnGBfJTL0NJ11exy/Tc2QEuv7hCXp24=
|
||||
github.com/prysmaticlabs/prombbolt v0.0.0-20210126082820-9b7adba6db7c h1:9PHRCuO/VN0s9k+RmLykho7AjDxblNYI5bYKed16NPU=
|
||||
|
||||
@@ -28,7 +28,7 @@ go_test(
|
||||
deps = [
|
||||
"//config/params:go_default_library",
|
||||
"//crypto/bls:go_default_library",
|
||||
"//encoding/ssz:go_default_library",
|
||||
"//encoding/ssz/equality:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"//proto/prysm/v1alpha1/attestation/aggregation:go_default_library",
|
||||
"//proto/prysm/v1alpha1/attestation/aggregation/testing:go_default_library",
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
"github.com/prysmaticlabs/go-bitfield"
|
||||
"github.com/prysmaticlabs/prysm/config/params"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation"
|
||||
aggtesting "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation/testing"
|
||||
@@ -48,7 +48,7 @@ func TestAggregateAttestations_AggregatePair(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
got, err := AggregatePair(tt.a1, tt.a2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, ssz.DeepEqual(got, tt.want))
|
||||
require.Equal(t, true, equality.DeepEqual(got, tt.want))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ go_library(
|
||||
importpath = "github.com/prysmaticlabs/prysm/testing/assertions",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//encoding/ssz:go_default_library",
|
||||
"//encoding/ssz/equality:go_default_library",
|
||||
"@com_github_d4l3k_messagediff//:go_default_library",
|
||||
"@com_github_sirupsen_logrus//hooks/test:go_default_library",
|
||||
"@org_golang_google_protobuf//proto:go_default_library",
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/d4l3k/messagediff"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
|
||||
"github.com/sirupsen/logrus/hooks/test"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
@@ -61,7 +61,7 @@ func DeepNotEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg
|
||||
|
||||
// DeepSSZEqual compares values using ssz.DeepEqual.
|
||||
func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
|
||||
if !ssz.DeepEqual(expected, actual) {
|
||||
if !equality.DeepEqual(expected, actual) {
|
||||
errMsg := parseMsg("Values are not equal", msg...)
|
||||
_, file, line, _ := runtime.Caller(2)
|
||||
diff, _ := messagediff.PrettyDiff(expected, actual)
|
||||
@@ -71,7 +71,7 @@ func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg
|
||||
|
||||
// DeepNotSSZEqual compares values using ssz.DeepEqual.
|
||||
func DeepNotSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
|
||||
if ssz.DeepEqual(expected, actual) {
|
||||
if equality.DeepEqual(expected, actual) {
|
||||
errMsg := parseMsg("Values are equal", msg...)
|
||||
_, file, line, _ := runtime.Caller(2)
|
||||
loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual)
|
||||
|
||||
@@ -12,7 +12,7 @@ go_library(
|
||||
deps = [
|
||||
"//beacon-chain/core/transition:go_default_library",
|
||||
"//beacon-chain/state/v1:go_default_library",
|
||||
"//encoding/ssz:go_default_library",
|
||||
"//encoding/ssz/equality:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"//proto/prysm/v1alpha1/wrapper:go_default_library",
|
||||
"//runtime/version:go_default_library",
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/kr/pretty"
|
||||
"github.com/prysmaticlabs/prysm/beacon-chain/core/transition"
|
||||
v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper"
|
||||
"github.com/prysmaticlabs/prysm/runtime/version"
|
||||
@@ -191,7 +191,7 @@ func main() {
|
||||
if err := dataFetcher(expectedPostStatePath, expectedState); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if !ssz.DeepEqual(expectedState, postState.InnerStateUnsafe()) {
|
||||
if !equality.DeepEqual(expectedState, postState.InnerStateUnsafe()) {
|
||||
diff, _ := messagediff.PrettyDiff(expectedState, postState.InnerStateUnsafe())
|
||||
log.Errorf("Derived state differs from provided post state: %s", diff)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user