mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-09 23:48:06 -05:00
Integration of Vectorized Sha256 In Prysm (#10166)
* add changes * fix for vectorize * fix bug * add new bench * use new algorithms * add latest updates * save progress * hack even more * add more changes * change library * go mod * fix deps * fix dumb bug * add flag and remove redundant code * clean up better * remove those ones * clean up benches * clean up benches * cleanup * gaz * revert change * potuz's review * potuz's review * potuz's review * gaz * potuz's review * remove cyclical import * revert ide changes * potuz's review * return
This commit is contained in:
@@ -258,3 +258,21 @@ func TestBeaconState_AppendValidator_DoesntMutateCopy(t *testing.T) {
|
|||||||
_, ok := st1.ValidatorIndexByPubkey(bytesutil.ToBytes48(val.PublicKey))
|
_, ok := st1.ValidatorIndexByPubkey(bytesutil.ToBytes48(val.PublicKey))
|
||||||
assert.Equal(t, false, ok, "Expected no validator index to be present in st1 for the newly inserted pubkey")
|
assert.Equal(t, false, ok, "Expected no validator index to be present in st1 for the newly inserted pubkey")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkBeaconState(b *testing.B) {
|
||||||
|
testState, _ := util.DeterministicGenesisState(b, 16000)
|
||||||
|
pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe())
|
||||||
|
require.NoError(b, err)
|
||||||
|
|
||||||
|
b.Run("Vectorized SHA256", func(b *testing.B) {
|
||||||
|
st, err := v1.InitializeFromProtoUnsafe(pbState)
|
||||||
|
require.NoError(b, err)
|
||||||
|
_, err = st.HashTreeRoot(context.Background())
|
||||||
|
assert.NoError(b, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Current SHA256", func(b *testing.B) {
|
||||||
|
_, err := pbState.HashTreeRoot()
|
||||||
|
require.NoError(b, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -32,10 +32,12 @@ go_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//beacon-chain/core/transition/stateutils:go_default_library",
|
"//beacon-chain/core/transition/stateutils:go_default_library",
|
||||||
|
"//config/features:go_default_library",
|
||||||
"//config/fieldparams:go_default_library",
|
"//config/fieldparams:go_default_library",
|
||||||
"//config/params:go_default_library",
|
"//config/params:go_default_library",
|
||||||
"//container/trie:go_default_library",
|
"//container/trie:go_default_library",
|
||||||
"//crypto/hash:go_default_library",
|
"//crypto/hash:go_default_library",
|
||||||
|
"//crypto/hash/htr:go_default_library",
|
||||||
"//encoding/bytesutil:go_default_library",
|
"//encoding/bytesutil:go_default_library",
|
||||||
"//encoding/ssz:go_default_library",
|
"//encoding/ssz:go_default_library",
|
||||||
"//math:go_default_library",
|
"//math:go_default_library",
|
||||||
@@ -51,6 +53,7 @@ go_test(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"benchmark_test.go",
|
"benchmark_test.go",
|
||||||
"field_root_test.go",
|
"field_root_test.go",
|
||||||
|
"field_root_validator_test.go",
|
||||||
"reference_bench_test.go",
|
"reference_bench_test.go",
|
||||||
"state_root_test.go",
|
"state_root_test.go",
|
||||||
"trie_helpers_test.go",
|
"trie_helpers_test.go",
|
||||||
@@ -58,11 +61,14 @@ go_test(
|
|||||||
],
|
],
|
||||||
embed = [":go_default_library"],
|
embed = [":go_default_library"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//beacon-chain/state:go_default_library",
|
||||||
|
"//config/features:go_default_library",
|
||||||
"//config/fieldparams:go_default_library",
|
"//config/fieldparams:go_default_library",
|
||||||
"//config/params:go_default_library",
|
"//config/params:go_default_library",
|
||||||
"//crypto/hash:go_default_library",
|
"//crypto/hash:go_default_library",
|
||||||
"//encoding/bytesutil:go_default_library",
|
"//encoding/bytesutil:go_default_library",
|
||||||
"//encoding/ssz:go_default_library",
|
"//encoding/ssz:go_default_library",
|
||||||
|
"//math:go_default_library",
|
||||||
"//proto/prysm/v1alpha1:go_default_library",
|
"//proto/prysm/v1alpha1:go_default_library",
|
||||||
"//runtime/interop:go_default_library",
|
"//runtime/interop:go_default_library",
|
||||||
"//testing/assert:go_default_library",
|
"//testing/assert:go_default_library",
|
||||||
|
|||||||
@@ -5,12 +5,25 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/prysmaticlabs/prysm/config/features"
|
||||||
fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams"
|
fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams"
|
||||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||||
|
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
|
||||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// number of field roots for the validator object.
|
||||||
|
validatorFieldRoots = 8
|
||||||
|
|
||||||
|
// Depth of tree representation of an individual
|
||||||
|
// validator.
|
||||||
|
// NumOfRoots = 2 ^ (TreeDepth)
|
||||||
|
// 8 = 2 ^ 3
|
||||||
|
validatorTreeDepth = 3
|
||||||
|
)
|
||||||
|
|
||||||
// ValidatorRegistryRoot computes the HashTreeRoot Merkleization of
|
// ValidatorRegistryRoot computes the HashTreeRoot Merkleization of
|
||||||
// a list of validator structs according to the Ethereum
|
// a list of validator structs according to the Ethereum
|
||||||
// Simple Serialize specification.
|
// Simple Serialize specification.
|
||||||
@@ -19,14 +32,20 @@ func ValidatorRegistryRoot(vals []*ethpb.Validator) ([32]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) {
|
func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) {
|
||||||
roots := make([][32]byte, len(validators))
|
|
||||||
hasher := hash.CustomSHA256Hasher()
|
hasher := hash.CustomSHA256Hasher()
|
||||||
for i := 0; i < len(validators); i++ {
|
|
||||||
val, err := validatorRoot(hasher, validators[i])
|
var err error
|
||||||
|
var roots [][32]byte
|
||||||
|
if features.Get().EnableVectorizedHTR {
|
||||||
|
roots, err = optimizedValidatorRoots(validators)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return [32]byte{}, errors.Wrap(err, "could not compute validators merkleization")
|
return [32]byte{}, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
roots, err = validatorRoots(hasher, validators)
|
||||||
|
if err != nil {
|
||||||
|
return [32]byte{}, err
|
||||||
}
|
}
|
||||||
roots[i] = val
|
|
||||||
}
|
}
|
||||||
|
|
||||||
validatorsRootsRoot, err := ssz.BitwiseMerkleizeArrays(hasher, roots, uint64(len(roots)), fieldparams.ValidatorRegistryLimit)
|
validatorsRootsRoot, err := ssz.BitwiseMerkleizeArrays(hasher, roots, uint64(len(roots)), fieldparams.ValidatorRegistryLimit)
|
||||||
@@ -45,6 +64,42 @@ func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) {
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validatorRoots(hasher func([]byte) [32]byte, validators []*ethpb.Validator) ([][32]byte, error) {
|
||||||
|
roots := make([][32]byte, len(validators))
|
||||||
|
for i := 0; i < len(validators); i++ {
|
||||||
|
val, err := validatorRoot(hasher, validators[i])
|
||||||
|
if err != nil {
|
||||||
|
return [][32]byte{}, errors.Wrap(err, "could not compute validators merkleization")
|
||||||
|
}
|
||||||
|
roots[i] = val
|
||||||
|
}
|
||||||
|
return roots, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func optimizedValidatorRoots(validators []*ethpb.Validator) ([][32]byte, error) {
|
||||||
|
roots := make([][32]byte, 0, len(validators)*validatorFieldRoots)
|
||||||
|
hasher := hash.CustomSHA256Hasher()
|
||||||
|
for i := 0; i < len(validators); i++ {
|
||||||
|
fRoots, err := ValidatorFieldRoots(hasher, validators[i])
|
||||||
|
if err != nil {
|
||||||
|
return [][32]byte{}, errors.Wrap(err, "could not compute validators merkleization")
|
||||||
|
}
|
||||||
|
roots = append(roots, fRoots...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A validator's tree can represented with a depth of 3. As log2(8) = 3
|
||||||
|
// Using this property we can lay out all the individual fields of a
|
||||||
|
// validator and hash them in single level using our vectorized routine.
|
||||||
|
for i := 0; i < validatorTreeDepth; i++ {
|
||||||
|
// Overwrite input lists as we are hashing by level
|
||||||
|
// and only need the highest level to proceed.
|
||||||
|
outputLen := len(roots) / 2
|
||||||
|
htr.VectorizedSha256(roots, roots)
|
||||||
|
roots = roots[:outputLen]
|
||||||
|
}
|
||||||
|
return roots, nil
|
||||||
|
}
|
||||||
|
|
||||||
func validatorRoot(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) {
|
func validatorRoot(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) {
|
||||||
if validator == nil {
|
if validator == nil {
|
||||||
return [32]byte{}, errors.New("nil validator")
|
return [32]byte{}, errors.New("nil validator")
|
||||||
|
|||||||
32
beacon-chain/state/stateutil/field_root_validator_test.go
Normal file
32
beacon-chain/state/stateutil/field_root_validator_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package stateutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
mathutil "github.com/prysmaticlabs/prysm/math"
|
||||||
|
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||||
|
"github.com/prysmaticlabs/prysm/testing/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidatorConstants(t *testing.T) {
|
||||||
|
v := ðpb.Validator{}
|
||||||
|
refV := reflect.ValueOf(v).Elem()
|
||||||
|
numFields := refV.NumField()
|
||||||
|
numOfValFields := 0
|
||||||
|
|
||||||
|
for i := 0; i < numFields; i++ {
|
||||||
|
if strings.Contains(refV.Type().Field(i).Name, "state") ||
|
||||||
|
strings.Contains(refV.Type().Field(i).Name, "sizeCache") ||
|
||||||
|
strings.Contains(refV.Type().Field(i).Name, "unknownFields") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
numOfValFields++
|
||||||
|
}
|
||||||
|
assert.Equal(t, validatorFieldRoots, numOfValFields)
|
||||||
|
assert.Equal(t, uint64(validatorFieldRoots), mathutil.PowerOf2(validatorTreeDepth))
|
||||||
|
|
||||||
|
_, err := ValidatorRegistryRoot([]*ethpb.Validator{v})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
@@ -5,8 +5,10 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/prysmaticlabs/prysm/config/features"
|
||||||
"github.com/prysmaticlabs/prysm/container/trie"
|
"github.com/prysmaticlabs/prysm/container/trie"
|
||||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||||
|
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
|
||||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||||
"github.com/prysmaticlabs/prysm/math"
|
"github.com/prysmaticlabs/prysm/math"
|
||||||
)
|
)
|
||||||
@@ -61,25 +63,43 @@ func ReturnTrieLayerVariable(elements [][32]byte, length uint64) [][]*[32]byte {
|
|||||||
layers[0] = transformedLeaves
|
layers[0] = transformedLeaves
|
||||||
buffer := bytes.NewBuffer([]byte{})
|
buffer := bytes.NewBuffer([]byte{})
|
||||||
buffer.Grow(64)
|
buffer.Grow(64)
|
||||||
|
|
||||||
for i := 0; i < int(depth); i++ {
|
for i := 0; i < int(depth); i++ {
|
||||||
oddNodeLength := len(layers[i])%2 == 1
|
layerLen := len(layers[i])
|
||||||
if oddNodeLength {
|
oddNodeLength := layerLen%2 == 1
|
||||||
zerohash := trie.ZeroHashes[i]
|
if features.Get().EnableVectorizedHTR {
|
||||||
layers[i] = append(layers[i], &zerohash)
|
if oddNodeLength {
|
||||||
|
zerohash := trie.ZeroHashes[i]
|
||||||
|
elements = append(elements, zerohash)
|
||||||
|
layerLen++
|
||||||
|
}
|
||||||
|
|
||||||
|
layers[i+1] = make([]*[32]byte, layerLen/2)
|
||||||
|
newElems := make([][32]byte, layerLen/2)
|
||||||
|
htr.VectorizedSha256(elements, newElems)
|
||||||
|
elements = newElems
|
||||||
|
for j := range elements {
|
||||||
|
layers[i+1][j] = &elements[j]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if oddNodeLength {
|
||||||
|
zerohash := trie.ZeroHashes[i]
|
||||||
|
layers[i] = append(layers[i], &zerohash)
|
||||||
|
}
|
||||||
|
updatedValues := make([]*[32]byte, 0, len(layers[i])/2)
|
||||||
|
for j := 0; j < len(layers[i]); j += 2 {
|
||||||
|
buffer.Write(layers[i][j][:])
|
||||||
|
buffer.Write(layers[i][j+1][:])
|
||||||
|
concat := hasher(buffer.Bytes())
|
||||||
|
updatedValues = append(updatedValues, &concat)
|
||||||
|
buffer.Reset()
|
||||||
|
}
|
||||||
|
// remove zerohash node from tree
|
||||||
|
if oddNodeLength {
|
||||||
|
layers[i] = layers[i][:len(layers[i])-1]
|
||||||
|
}
|
||||||
|
layers[i+1] = updatedValues
|
||||||
}
|
}
|
||||||
updatedValues := make([]*[32]byte, 0, len(layers[i])/2)
|
|
||||||
for j := 0; j < len(layers[i]); j += 2 {
|
|
||||||
buffer.Write(layers[i][j][:])
|
|
||||||
buffer.Write(layers[i][j+1][:])
|
|
||||||
concat := hasher(buffer.Bytes())
|
|
||||||
updatedValues = append(updatedValues, &concat)
|
|
||||||
buffer.Reset()
|
|
||||||
}
|
|
||||||
// remove zerohash node from tree
|
|
||||||
if oddNodeLength {
|
|
||||||
layers[i] = layers[i][:len(layers[i])-1]
|
|
||||||
}
|
|
||||||
layers[i+1] = updatedValues
|
|
||||||
}
|
}
|
||||||
return layers
|
return layers
|
||||||
}
|
}
|
||||||
@@ -277,18 +297,24 @@ func MerkleizeTrieLeaves(layers [][][32]byte, hashLayer [][32]byte,
|
|||||||
chunkBuffer := bytes.NewBuffer([]byte{})
|
chunkBuffer := bytes.NewBuffer([]byte{})
|
||||||
chunkBuffer.Grow(64)
|
chunkBuffer.Grow(64)
|
||||||
for len(hashLayer) > 1 && i < len(layers) {
|
for len(hashLayer) > 1 && i < len(layers) {
|
||||||
layer := make([][32]byte, len(hashLayer)/2)
|
|
||||||
if !math.IsPowerOf2(uint64(len(hashLayer))) {
|
if !math.IsPowerOf2(uint64(len(hashLayer))) {
|
||||||
return nil, nil, errors.Errorf("hash layer is a non power of 2: %d", len(hashLayer))
|
return nil, nil, errors.Errorf("hash layer is a non power of 2: %d", len(hashLayer))
|
||||||
}
|
}
|
||||||
for j := 0; j < len(hashLayer); j += 2 {
|
if features.Get().EnableVectorizedHTR {
|
||||||
chunkBuffer.Write(hashLayer[j][:])
|
newLayer := make([][32]byte, len(hashLayer)/2)
|
||||||
chunkBuffer.Write(hashLayer[j+1][:])
|
htr.VectorizedSha256(hashLayer, newLayer)
|
||||||
hashedChunk := hasher(chunkBuffer.Bytes())
|
hashLayer = newLayer
|
||||||
layer[j/2] = hashedChunk
|
} else {
|
||||||
chunkBuffer.Reset()
|
layer := make([][32]byte, len(hashLayer)/2)
|
||||||
|
for j := 0; j < len(hashLayer); j += 2 {
|
||||||
|
chunkBuffer.Write(hashLayer[j][:])
|
||||||
|
chunkBuffer.Write(hashLayer[j+1][:])
|
||||||
|
hashedChunk := hasher(chunkBuffer.Bytes())
|
||||||
|
layer[j/2] = hashedChunk
|
||||||
|
chunkBuffer.Reset()
|
||||||
|
}
|
||||||
|
hashLayer = layer
|
||||||
}
|
}
|
||||||
hashLayer = layer
|
|
||||||
layers[i] = hashLayer
|
layers[i] = hashLayer
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
types "github.com/prysmaticlabs/eth2-types"
|
types "github.com/prysmaticlabs/eth2-types"
|
||||||
|
"github.com/prysmaticlabs/prysm/beacon-chain/state"
|
||||||
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
|
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
|
||||||
|
"github.com/prysmaticlabs/prysm/config/features"
|
||||||
"github.com/prysmaticlabs/prysm/config/params"
|
"github.com/prysmaticlabs/prysm/config/params"
|
||||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||||
@@ -18,15 +20,56 @@ func TestReturnTrieLayer_OK(t *testing.T) {
|
|||||||
newState, _ := util.DeterministicGenesisState(t, 32)
|
newState, _ := util.DeterministicGenesisState(t, 32)
|
||||||
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
|
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
blockRts := newState.BlockRoots()
|
roots := retrieveBlockRoots(newState)
|
||||||
roots := make([][32]byte, 0, len(blockRts))
|
|
||||||
for _, rt := range blockRts {
|
|
||||||
roots = append(roots, bytesutil.ToBytes32(rt))
|
|
||||||
}
|
|
||||||
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
|
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
newRoot := *layers[len(layers)-1][0]
|
newRoot := *layers[len(layers)-1][0]
|
||||||
assert.Equal(t, root, newRoot)
|
assert.Equal(t, root, newRoot)
|
||||||
|
|
||||||
|
flags := &features.Flags{}
|
||||||
|
flags.EnableVectorizedHTR = true
|
||||||
|
reset := features.InitWithReset(flags)
|
||||||
|
defer reset()
|
||||||
|
|
||||||
|
layers, err = stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
lastRoot := *layers[len(layers)-1][0]
|
||||||
|
assert.Equal(t, root, lastRoot)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkReturnTrieLayer_NormalAlgorithm(b *testing.B) {
|
||||||
|
newState, _ := util.DeterministicGenesisState(b, 32)
|
||||||
|
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
|
||||||
|
require.NoError(b, err)
|
||||||
|
roots := retrieveBlockRoots(newState)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
|
||||||
|
assert.NoError(b, err)
|
||||||
|
newRoot := *layers[len(layers)-1][0]
|
||||||
|
assert.Equal(b, root, newRoot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkReturnTrieLayer_VectorizedAlgorithm(b *testing.B) {
|
||||||
|
flags := &features.Flags{}
|
||||||
|
flags.EnableVectorizedHTR = true
|
||||||
|
reset := features.InitWithReset(flags)
|
||||||
|
defer reset()
|
||||||
|
|
||||||
|
newState, _ := util.DeterministicGenesisState(b, 32)
|
||||||
|
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
|
||||||
|
require.NoError(b, err)
|
||||||
|
roots := retrieveBlockRoots(newState)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
|
||||||
|
assert.NoError(b, err)
|
||||||
|
newRoot := *layers[len(layers)-1][0]
|
||||||
|
assert.Equal(b, root, newRoot)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReturnTrieLayerVariable_OK(t *testing.T) {
|
func TestReturnTrieLayerVariable_OK(t *testing.T) {
|
||||||
@@ -46,15 +89,73 @@ func TestReturnTrieLayerVariable_OK(t *testing.T) {
|
|||||||
newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators)))
|
newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators)))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, root, newRoot)
|
assert.Equal(t, root, newRoot)
|
||||||
|
|
||||||
|
flags := &features.Flags{}
|
||||||
|
flags.EnableVectorizedHTR = true
|
||||||
|
reset := features.InitWithReset(flags)
|
||||||
|
defer reset()
|
||||||
|
|
||||||
|
layers = stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit)
|
||||||
|
lastRoot := *layers[len(layers)-1][0]
|
||||||
|
lastRoot, err = stateutil.AddInMixin(lastRoot, uint64(len(validators)))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, root, lastRoot)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkReturnTrieLayerVariable_NormalAlgorithm(b *testing.B) {
|
||||||
|
newState, _ := util.DeterministicGenesisState(b, 16000)
|
||||||
|
root, err := stateutil.ValidatorRegistryRoot(newState.Validators())
|
||||||
|
require.NoError(b, err)
|
||||||
|
hasher := hash.CustomSHA256Hasher()
|
||||||
|
validators := newState.Validators()
|
||||||
|
roots := make([][32]byte, 0, len(validators))
|
||||||
|
for _, val := range validators {
|
||||||
|
rt, err := stateutil.ValidatorRootWithHasher(hasher, val)
|
||||||
|
require.NoError(b, err)
|
||||||
|
roots = append(roots, rt)
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
layers := stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit)
|
||||||
|
newRoot := *layers[len(layers)-1][0]
|
||||||
|
newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators)))
|
||||||
|
require.NoError(b, err)
|
||||||
|
assert.Equal(b, root, newRoot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkReturnTrieLayerVariable_VectorizedAlgorithm(b *testing.B) {
|
||||||
|
flags := &features.Flags{}
|
||||||
|
flags.EnableVectorizedHTR = true
|
||||||
|
reset := features.InitWithReset(flags)
|
||||||
|
defer reset()
|
||||||
|
|
||||||
|
newState, _ := util.DeterministicGenesisState(b, 16000)
|
||||||
|
root, err := stateutil.ValidatorRegistryRoot(newState.Validators())
|
||||||
|
require.NoError(b, err)
|
||||||
|
hasher := hash.CustomSHA256Hasher()
|
||||||
|
validators := newState.Validators()
|
||||||
|
roots := make([][32]byte, 0, len(validators))
|
||||||
|
for _, val := range validators {
|
||||||
|
rt, err := stateutil.ValidatorRootWithHasher(hasher, val)
|
||||||
|
require.NoError(b, err)
|
||||||
|
roots = append(roots, rt)
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
layers := stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit)
|
||||||
|
newRoot := *layers[len(layers)-1][0]
|
||||||
|
newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators)))
|
||||||
|
require.NoError(b, err)
|
||||||
|
assert.Equal(b, root, newRoot)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRecomputeFromLayer_FixedSizedArray(t *testing.T) {
|
func TestRecomputeFromLayer_FixedSizedArray(t *testing.T) {
|
||||||
newState, _ := util.DeterministicGenesisState(t, 32)
|
newState, _ := util.DeterministicGenesisState(t, 32)
|
||||||
blockRts := newState.BlockRoots()
|
roots := retrieveBlockRoots(newState)
|
||||||
roots := make([][32]byte, 0, len(blockRts))
|
|
||||||
for _, rt := range blockRts {
|
|
||||||
roots = append(roots, bytesutil.ToBytes32(rt))
|
|
||||||
}
|
|
||||||
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
|
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -120,3 +221,12 @@ func TestMerkleizeTrieLeaves_BadHashLayer(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assert.ErrorContains(t, "hash layer is a non power of 2", err)
|
assert.ErrorContains(t, "hash layer is a non power of 2", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func retrieveBlockRoots(b state.BeaconState) [][32]byte {
|
||||||
|
blockRts := b.BlockRoots()
|
||||||
|
roots := make([][32]byte, 0, len(blockRts))
|
||||||
|
for _, rt := range blockRts {
|
||||||
|
roots = append(roots, bytesutil.ToBytes32(rt))
|
||||||
|
}
|
||||||
|
return roots
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/prysmaticlabs/prysm/config/features"
|
||||||
fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams"
|
fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams"
|
||||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||||
|
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
|
||||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||||
@@ -14,6 +16,16 @@ import (
|
|||||||
// ValidatorRootWithHasher describes a method from which the hash tree root
|
// ValidatorRootWithHasher describes a method from which the hash tree root
|
||||||
// of a validator is returned.
|
// of a validator is returned.
|
||||||
func ValidatorRootWithHasher(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) {
|
func ValidatorRootWithHasher(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) {
|
||||||
|
fieldRoots, err := ValidatorFieldRoots(hasher, validator)
|
||||||
|
if err != nil {
|
||||||
|
return [32]byte{}, err
|
||||||
|
}
|
||||||
|
return ssz.BitwiseMerkleizeArrays(hasher, fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidatorFieldRoots describes a method from which the hash tree root
|
||||||
|
// of a validator is returned.
|
||||||
|
func ValidatorFieldRoots(hasher ssz.HashFn, validator *ethpb.Validator) ([][32]byte, error) {
|
||||||
var fieldRoots [][32]byte
|
var fieldRoots [][32]byte
|
||||||
if validator != nil {
|
if validator != nil {
|
||||||
pubkey := bytesutil.ToBytes48(validator.PublicKey)
|
pubkey := bytesutil.ToBytes48(validator.PublicKey)
|
||||||
@@ -40,18 +52,25 @@ func ValidatorRootWithHasher(hasher ssz.HashFn, validator *ethpb.Validator) ([32
|
|||||||
binary.LittleEndian.PutUint64(withdrawalBuf[:8], uint64(validator.WithdrawableEpoch))
|
binary.LittleEndian.PutUint64(withdrawalBuf[:8], uint64(validator.WithdrawableEpoch))
|
||||||
|
|
||||||
// Public key.
|
// Public key.
|
||||||
pubKeyChunks, err := ssz.Pack([][]byte{pubkey[:]})
|
pubKeyChunks, err := ssz.PackByChunk([][]byte{pubkey[:]})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return [32]byte{}, err
|
return [][32]byte{}, err
|
||||||
}
|
}
|
||||||
pubKeyRoot, err := ssz.BitwiseMerkleize(hasher, pubKeyChunks, uint64(len(pubKeyChunks)), uint64(len(pubKeyChunks)))
|
var pubKeyRoot [32]byte
|
||||||
if err != nil {
|
if features.Get().EnableVectorizedHTR {
|
||||||
return [32]byte{}, err
|
outputChunk := make([][32]byte, 1)
|
||||||
|
htr.VectorizedSha256(pubKeyChunks, outputChunk)
|
||||||
|
pubKeyRoot = outputChunk[0]
|
||||||
|
} else {
|
||||||
|
pubKeyRoot, err = ssz.BitwiseMerkleizeArrays(hasher, pubKeyChunks, uint64(len(pubKeyChunks)), uint64(len(pubKeyChunks)))
|
||||||
|
if err != nil {
|
||||||
|
return [][32]byte{}, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
fieldRoots = [][32]byte{pubKeyRoot, withdrawCreds, effectiveBalanceBuf, slashBuf, activationEligibilityBuf,
|
fieldRoots = [][32]byte{pubKeyRoot, withdrawCreds, effectiveBalanceBuf, slashBuf, activationEligibilityBuf,
|
||||||
activationBuf, exitBuf, withdrawalBuf}
|
activationBuf, exitBuf, withdrawalBuf}
|
||||||
}
|
}
|
||||||
return ssz.BitwiseMerkleizeArrays(hasher, fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots)))
|
return fieldRoots, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Uint64ListRootWithRegistryLimit computes the HashTreeRoot Merkleization of
|
// Uint64ListRootWithRegistryLimit computes the HashTreeRoot Merkleization of
|
||||||
|
|||||||
@@ -177,6 +177,24 @@ func TestBeaconState_HashTreeRoot(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkBeaconState(b *testing.B) {
|
||||||
|
testState, _ := util.DeterministicGenesisState(b, 16000)
|
||||||
|
pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe())
|
||||||
|
require.NoError(b, err)
|
||||||
|
|
||||||
|
b.Run("Vectorized SHA256", func(b *testing.B) {
|
||||||
|
st, err := v1.InitializeFromProtoUnsafe(pbState)
|
||||||
|
require.NoError(b, err)
|
||||||
|
_, err = st.HashTreeRoot(context.Background())
|
||||||
|
assert.NoError(b, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Current SHA256", func(b *testing.B) {
|
||||||
|
_, err := pbState.HashTreeRoot()
|
||||||
|
require.NoError(b, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestBeaconState_HashTreeRoot_FieldTrie(t *testing.T) {
|
func TestBeaconState_HashTreeRoot_FieldTrie(t *testing.T) {
|
||||||
testState, _ := util.DeterministicGenesisState(t, 64)
|
testState, _ := util.DeterministicGenesisState(t, 64)
|
||||||
|
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ go_library(
|
|||||||
"//crypto/bls:go_default_library",
|
"//crypto/bls:go_default_library",
|
||||||
"//crypto/rand:go_default_library",
|
"//crypto/rand:go_default_library",
|
||||||
"//encoding/bytesutil:go_default_library",
|
"//encoding/bytesutil:go_default_library",
|
||||||
"//encoding/ssz:go_default_library",
|
"//encoding/ssz/equality:go_default_library",
|
||||||
"//monitoring/tracing:go_default_library",
|
"//monitoring/tracing:go_default_library",
|
||||||
"//network/forks:go_default_library",
|
"//network/forks:go_default_library",
|
||||||
"//proto/prysm/v1alpha1:go_default_library",
|
"//proto/prysm/v1alpha1:go_default_library",
|
||||||
@@ -195,7 +195,7 @@ go_test(
|
|||||||
"//crypto/bls:go_default_library",
|
"//crypto/bls:go_default_library",
|
||||||
"//crypto/rand:go_default_library",
|
"//crypto/rand:go_default_library",
|
||||||
"//encoding/bytesutil:go_default_library",
|
"//encoding/bytesutil:go_default_library",
|
||||||
"//encoding/ssz:go_default_library",
|
"//encoding/ssz/equality:go_default_library",
|
||||||
"//network/forks:go_default_library",
|
"//network/forks:go_default_library",
|
||||||
"//proto/prysm/v1alpha1:go_default_library",
|
"//proto/prysm/v1alpha1:go_default_library",
|
||||||
"//proto/prysm/v1alpha1/attestation:go_default_library",
|
"//proto/prysm/v1alpha1/attestation:go_default_library",
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
"github.com/prysmaticlabs/prysm/config/params"
|
"github.com/prysmaticlabs/prysm/config/params"
|
||||||
"github.com/prysmaticlabs/prysm/crypto/rand"
|
"github.com/prysmaticlabs/prysm/crypto/rand"
|
||||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
|
||||||
"github.com/prysmaticlabs/prysm/monitoring/tracing"
|
"github.com/prysmaticlabs/prysm/monitoring/tracing"
|
||||||
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/block"
|
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/block"
|
||||||
"github.com/prysmaticlabs/prysm/time/slots"
|
"github.com/prysmaticlabs/prysm/time/slots"
|
||||||
@@ -328,7 +328,7 @@ func (s *Service) deleteBlockFromPendingQueue(slot types.Slot, b block.SignedBea
|
|||||||
|
|
||||||
newBlks := make([]block.SignedBeaconBlock, 0, len(blks))
|
newBlks := make([]block.SignedBeaconBlock, 0, len(blks))
|
||||||
for _, blk := range blks {
|
for _, blk := range blks {
|
||||||
if ssz.DeepEqual(blk.Proto(), b.Proto()) {
|
if equality.DeepEqual(blk.Proto(), b.Proto()) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
newBlks = append(newBlks, blk)
|
newBlks = append(newBlks, blk)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import (
|
|||||||
"github.com/prysmaticlabs/prysm/beacon-chain/p2p"
|
"github.com/prysmaticlabs/prysm/beacon-chain/p2p"
|
||||||
p2ptest "github.com/prysmaticlabs/prysm/beacon-chain/p2p/testing"
|
p2ptest "github.com/prysmaticlabs/prysm/beacon-chain/p2p/testing"
|
||||||
"github.com/prysmaticlabs/prysm/config/params"
|
"github.com/prysmaticlabs/prysm/config/params"
|
||||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
|
||||||
pb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
pb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||||
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/metadata"
|
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/metadata"
|
||||||
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper"
|
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper"
|
||||||
@@ -125,7 +125,7 @@ func TestMetadataRPCHandler_SendsMetadata(t *testing.T) {
|
|||||||
metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID())
|
metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID())
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
if !ssz.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
|
if !equality.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
|
||||||
t.Fatalf("MetadataV0 unequal, received %v but wanted %v", metadata, p2.LocalMetadata)
|
t.Fatalf("MetadataV0 unequal, received %v but wanted %v", metadata, p2.LocalMetadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,7 +213,7 @@ func TestMetadataRPCHandler_SendsMetadataAltair(t *testing.T) {
|
|||||||
metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID())
|
metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID())
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
if !ssz.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
|
if !equality.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
|
||||||
t.Fatalf("MetadataV1 unequal, received %v but wanted %v", metadata, p2.LocalMetadata)
|
t.Fatalf("MetadataV1 unequal, received %v but wanted %v", metadata, p2.LocalMetadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,8 @@ type Flags struct {
|
|||||||
CorrectlyInsertOrphanedAtts bool
|
CorrectlyInsertOrphanedAtts bool
|
||||||
CorrectlyPruneCanonicalAtts bool
|
CorrectlyPruneCanonicalAtts bool
|
||||||
|
|
||||||
EnableNativeState bool // EnableNativeState defines whether the beacon state will be represented as a pure Go struct or a Go struct that wraps a proto struct.
|
EnableNativeState bool // EnableNativeState defines whether the beacon state will be represented as a pure Go struct or a Go struct that wraps a proto struct.
|
||||||
|
EnableVectorizedHTR bool // EnableVectorizedHTR specifies whether the beacon state will use the optimized sha256 routines.
|
||||||
|
|
||||||
// KeystoreImportDebounceInterval specifies the time duration the validator waits to reload new keys if they have
|
// KeystoreImportDebounceInterval specifies the time duration the validator waits to reload new keys if they have
|
||||||
// changed on disk. This feature is for advanced use cases only.
|
// changed on disk. This feature is for advanced use cases only.
|
||||||
@@ -222,6 +223,10 @@ func ConfigureBeaconChain(ctx *cli.Context) {
|
|||||||
logEnabled(enableNativeState)
|
logEnabled(enableNativeState)
|
||||||
cfg.EnableNativeState = true
|
cfg.EnableNativeState = true
|
||||||
}
|
}
|
||||||
|
if ctx.Bool(enableVecHTR.Name) {
|
||||||
|
logEnabled(enableVecHTR)
|
||||||
|
cfg.EnableVectorizedHTR = true
|
||||||
|
}
|
||||||
Init(cfg)
|
Init(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -134,11 +134,16 @@ var (
|
|||||||
Name: "enable-native-state",
|
Name: "enable-native-state",
|
||||||
Usage: "Enables representing the beacon state as a pure Go struct.",
|
Usage: "Enables representing the beacon state as a pure Go struct.",
|
||||||
}
|
}
|
||||||
|
enableVecHTR = &cli.BoolFlag{
|
||||||
|
Name: "enable-vectorized-htr",
|
||||||
|
Usage: "Enables new go sha256 library which utilizes optimized routines for merkle trees",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// devModeFlags holds list of flags that are set when development mode is on.
|
// devModeFlags holds list of flags that are set when development mode is on.
|
||||||
var devModeFlags = []cli.Flag{
|
var devModeFlags = []cli.Flag{
|
||||||
enablePeerScorer,
|
enablePeerScorer,
|
||||||
|
enableVecHTR,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidatorFlags contains a list of all the feature flags that apply to the validator client.
|
// ValidatorFlags contains a list of all the feature flags that apply to the validator client.
|
||||||
@@ -183,6 +188,7 @@ var BeaconChainFlags = append(deprecatedFlags, []cli.Flag{
|
|||||||
disableBatchGossipVerification,
|
disableBatchGossipVerification,
|
||||||
disableBalanceTrieComputation,
|
disableBalanceTrieComputation,
|
||||||
enableNativeState,
|
enableNativeState,
|
||||||
|
enableVecHTR,
|
||||||
}...)
|
}...)
|
||||||
|
|
||||||
// E2EBeaconChainFlags contains a list of the beacon chain feature flags to be tested in E2E.
|
// E2EBeaconChainFlags contains a list of the beacon chain feature flags to be tested in E2E.
|
||||||
|
|||||||
9
crypto/hash/htr/BUILD.bazel
Normal file
9
crypto/hash/htr/BUILD.bazel
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
load("@prysm//tools/go:def.bzl", "go_library")
|
||||||
|
|
||||||
|
go_library(
|
||||||
|
name = "go_default_library",
|
||||||
|
srcs = ["hashtree.go"],
|
||||||
|
importpath = "github.com/prysmaticlabs/prysm/crypto/hash/htr",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = ["@com_github_prysmaticlabs_gohashtree//:go_default_library"],
|
||||||
|
)
|
||||||
17
crypto/hash/htr/hashtree.go
Normal file
17
crypto/hash/htr/hashtree.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package htr
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/prysmaticlabs/gohashtree"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VectorizedSha256 takes a list of roots and hashes them using CPU
|
||||||
|
// specific vector instructions. Depending on host machine's specific
|
||||||
|
// hardware configuration, using this routine can lead to a significant
|
||||||
|
// performance improvement compared to the default method of hashing
|
||||||
|
// lists.
|
||||||
|
func VectorizedSha256(inputList [][32]byte, outputList [][32]byte) {
|
||||||
|
err := gohashtree.Hash(outputList, inputList)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
7
deps.bzl
7
deps.bzl
@@ -2982,6 +2982,13 @@ def prysm_deps():
|
|||||||
sum = "h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw=",
|
sum = "h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw=",
|
||||||
version = "v0.0.0-20210809151128-385d8c5e3fb7",
|
version = "v0.0.0-20210809151128-385d8c5e3fb7",
|
||||||
)
|
)
|
||||||
|
go_repository(
|
||||||
|
name = "com_github_prysmaticlabs_gohashtree",
|
||||||
|
importpath = "github.com/prysmaticlabs/gohashtree",
|
||||||
|
sum = "h1:qVXMFNUGWbPpP7986YPsmL84A54oY7b/N+SOIBS3i5w=",
|
||||||
|
version = "v0.0.0-20220208111633-0606f58df32f",
|
||||||
|
)
|
||||||
|
|
||||||
go_repository(
|
go_repository(
|
||||||
name = "com_github_prysmaticlabs_prombbolt",
|
name = "com_github_prysmaticlabs_prombbolt",
|
||||||
importpath = "github.com/prysmaticlabs/prombbolt",
|
importpath = "github.com/prysmaticlabs/prombbolt",
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ load("@prysm//tools/go:def.bzl", "go_library", "go_test")
|
|||||||
go_library(
|
go_library(
|
||||||
name = "go_default_library",
|
name = "go_default_library",
|
||||||
srcs = [
|
srcs = [
|
||||||
"deep_equal.go",
|
|
||||||
"hashers.go",
|
"hashers.go",
|
||||||
"helpers.go",
|
"helpers.go",
|
||||||
"htrutils.go",
|
"htrutils.go",
|
||||||
@@ -12,16 +11,16 @@ go_library(
|
|||||||
importpath = "github.com/prysmaticlabs/prysm/encoding/ssz",
|
importpath = "github.com/prysmaticlabs/prysm/encoding/ssz",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//config/features:go_default_library",
|
||||||
"//config/fieldparams:go_default_library",
|
"//config/fieldparams:go_default_library",
|
||||||
"//container/trie:go_default_library",
|
"//container/trie:go_default_library",
|
||||||
"//crypto/hash:go_default_library",
|
"//crypto/hash:go_default_library",
|
||||||
|
"//crypto/hash/htr:go_default_library",
|
||||||
"//encoding/bytesutil:go_default_library",
|
"//encoding/bytesutil:go_default_library",
|
||||||
"//proto/prysm/v1alpha1:go_default_library",
|
"//proto/prysm/v1alpha1:go_default_library",
|
||||||
"@com_github_minio_sha256_simd//:go_default_library",
|
"@com_github_minio_sha256_simd//:go_default_library",
|
||||||
"@com_github_pkg_errors//:go_default_library",
|
"@com_github_pkg_errors//:go_default_library",
|
||||||
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
|
|
||||||
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
|
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
|
||||||
"@org_golang_google_protobuf//proto:go_default_library",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,7 +28,6 @@ go_test(
|
|||||||
name = "go_default_test",
|
name = "go_default_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = [
|
srcs = [
|
||||||
"deep_equal_test.go",
|
|
||||||
"hashers_test.go",
|
"hashers_test.go",
|
||||||
"helpers_test.go",
|
"helpers_test.go",
|
||||||
"htrutils_test.go",
|
"htrutils_test.go",
|
||||||
|
|||||||
22
encoding/ssz/equality/BUILD.bazel
Normal file
22
encoding/ssz/equality/BUILD.bazel
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
load("@prysm//tools/go:def.bzl", "go_library", "go_test")
|
||||||
|
|
||||||
|
go_library(
|
||||||
|
name = "go_default_library",
|
||||||
|
srcs = ["deep_equal.go"],
|
||||||
|
importpath = "github.com/prysmaticlabs/prysm/encoding/ssz/equality",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
|
||||||
|
"@org_golang_google_protobuf//proto:go_default_library",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
go_test(
|
||||||
|
name = "go_default_test",
|
||||||
|
srcs = ["deep_equal_test.go"],
|
||||||
|
deps = [
|
||||||
|
":go_default_library",
|
||||||
|
"//proto/prysm/v1alpha1:go_default_library",
|
||||||
|
"//testing/assert:go_default_library",
|
||||||
|
],
|
||||||
|
)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package ssz
|
package equality
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -1,34 +1,34 @@
|
|||||||
package ssz_test
|
package equality_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
|
||||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||||
"github.com/prysmaticlabs/prysm/testing/assert"
|
"github.com/prysmaticlabs/prysm/testing/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDeepEqualBasicTypes(t *testing.T) {
|
func TestDeepEqualBasicTypes(t *testing.T) {
|
||||||
assert.Equal(t, true, ssz.DeepEqual(true, true))
|
assert.Equal(t, true, equality.DeepEqual(true, true))
|
||||||
assert.Equal(t, false, ssz.DeepEqual(true, false))
|
assert.Equal(t, false, equality.DeepEqual(true, false))
|
||||||
|
|
||||||
assert.Equal(t, true, ssz.DeepEqual(byte(222), byte(222)))
|
assert.Equal(t, true, equality.DeepEqual(byte(222), byte(222)))
|
||||||
assert.Equal(t, false, ssz.DeepEqual(byte(222), byte(111)))
|
assert.Equal(t, false, equality.DeepEqual(byte(222), byte(111)))
|
||||||
|
|
||||||
assert.Equal(t, true, ssz.DeepEqual(uint64(1234567890), uint64(1234567890)))
|
assert.Equal(t, true, equality.DeepEqual(uint64(1234567890), uint64(1234567890)))
|
||||||
assert.Equal(t, false, ssz.DeepEqual(uint64(1234567890), uint64(987653210)))
|
assert.Equal(t, false, equality.DeepEqual(uint64(1234567890), uint64(987653210)))
|
||||||
|
|
||||||
assert.Equal(t, true, ssz.DeepEqual("hello", "hello"))
|
assert.Equal(t, true, equality.DeepEqual("hello", "hello"))
|
||||||
assert.Equal(t, false, ssz.DeepEqual("hello", "world"))
|
assert.Equal(t, false, equality.DeepEqual("hello", "world"))
|
||||||
|
|
||||||
assert.Equal(t, true, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 3}))
|
assert.Equal(t, true, equality.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 3}))
|
||||||
assert.Equal(t, false, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 4}))
|
assert.Equal(t, false, equality.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 4}))
|
||||||
|
|
||||||
var nilSlice1, nilSlice2 []byte
|
var nilSlice1, nilSlice2 []byte
|
||||||
assert.Equal(t, true, ssz.DeepEqual(nilSlice1, nilSlice2))
|
assert.Equal(t, true, equality.DeepEqual(nilSlice1, nilSlice2))
|
||||||
assert.Equal(t, true, ssz.DeepEqual(nilSlice1, []byte{}))
|
assert.Equal(t, true, equality.DeepEqual(nilSlice1, []byte{}))
|
||||||
assert.Equal(t, true, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 3}))
|
assert.Equal(t, true, equality.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 3}))
|
||||||
assert.Equal(t, false, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 4}))
|
assert.Equal(t, false, equality.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 4}))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeepEqualStructs(t *testing.T) {
|
func TestDeepEqualStructs(t *testing.T) {
|
||||||
@@ -39,8 +39,8 @@ func TestDeepEqualStructs(t *testing.T) {
|
|||||||
store1 := Store{uint64(1234), nil}
|
store1 := Store{uint64(1234), nil}
|
||||||
store2 := Store{uint64(1234), []byte{}}
|
store2 := Store{uint64(1234), []byte{}}
|
||||||
store3 := Store{uint64(4321), []byte{}}
|
store3 := Store{uint64(4321), []byte{}}
|
||||||
assert.Equal(t, true, ssz.DeepEqual(store1, store2))
|
assert.Equal(t, true, equality.DeepEqual(store1, store2))
|
||||||
assert.Equal(t, false, ssz.DeepEqual(store1, store3))
|
assert.Equal(t, false, equality.DeepEqual(store1, store3))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeepEqualStructs_Unexported(t *testing.T) {
|
func TestDeepEqualStructs_Unexported(t *testing.T) {
|
||||||
@@ -53,14 +53,14 @@ func TestDeepEqualStructs_Unexported(t *testing.T) {
|
|||||||
store2 := Store{uint64(1234), []byte{}, "hi there"}
|
store2 := Store{uint64(1234), []byte{}, "hi there"}
|
||||||
store3 := Store{uint64(4321), []byte{}, "wow"}
|
store3 := Store{uint64(4321), []byte{}, "wow"}
|
||||||
store4 := Store{uint64(4321), []byte{}, "bow wow"}
|
store4 := Store{uint64(4321), []byte{}, "bow wow"}
|
||||||
assert.Equal(t, true, ssz.DeepEqual(store1, store2))
|
assert.Equal(t, true, equality.DeepEqual(store1, store2))
|
||||||
assert.Equal(t, false, ssz.DeepEqual(store1, store3))
|
assert.Equal(t, false, equality.DeepEqual(store1, store3))
|
||||||
assert.Equal(t, false, ssz.DeepEqual(store3, store4))
|
assert.Equal(t, false, equality.DeepEqual(store3, store4))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeepEqualProto(t *testing.T) {
|
func TestDeepEqualProto(t *testing.T) {
|
||||||
var fork1, fork2 *ethpb.Fork
|
var fork1, fork2 *ethpb.Fork
|
||||||
assert.Equal(t, true, ssz.DeepEqual(fork1, fork2))
|
assert.Equal(t, true, equality.DeepEqual(fork1, fork2))
|
||||||
|
|
||||||
fork1 = ðpb.Fork{
|
fork1 = ðpb.Fork{
|
||||||
PreviousVersion: []byte{123},
|
PreviousVersion: []byte{123},
|
||||||
@@ -72,8 +72,8 @@ func TestDeepEqualProto(t *testing.T) {
|
|||||||
CurrentVersion: []byte{125},
|
CurrentVersion: []byte{125},
|
||||||
Epoch: 1234567890,
|
Epoch: 1234567890,
|
||||||
}
|
}
|
||||||
assert.Equal(t, true, ssz.DeepEqual(fork1, fork1))
|
assert.Equal(t, true, equality.DeepEqual(fork1, fork1))
|
||||||
assert.Equal(t, false, ssz.DeepEqual(fork1, fork2))
|
assert.Equal(t, false, equality.DeepEqual(fork1, fork2))
|
||||||
|
|
||||||
checkpoint1 := ðpb.Checkpoint{
|
checkpoint1 := ðpb.Checkpoint{
|
||||||
Epoch: 1234567890,
|
Epoch: 1234567890,
|
||||||
@@ -83,7 +83,7 @@ func TestDeepEqualProto(t *testing.T) {
|
|||||||
Epoch: 1234567890,
|
Epoch: 1234567890,
|
||||||
Root: nil,
|
Root: nil,
|
||||||
}
|
}
|
||||||
assert.Equal(t, true, ssz.DeepEqual(checkpoint1, checkpoint2))
|
assert.Equal(t, true, equality.DeepEqual(checkpoint1, checkpoint2))
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_IsProto(t *testing.T) {
|
func Test_IsProto(t *testing.T) {
|
||||||
@@ -125,7 +125,7 @@ func Test_IsProto(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if got := ssz.IsProto(tt.item); got != tt.want {
|
if got := equality.IsProto(tt.item); got != tt.want {
|
||||||
t.Errorf("isProtoSlice() = %v, want %v", got, tt.want)
|
t.Errorf("isProtoSlice() = %v, want %v", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/minio/sha256-simd"
|
"github.com/minio/sha256-simd"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/prysmaticlabs/go-bitfield"
|
"github.com/prysmaticlabs/go-bitfield"
|
||||||
|
"github.com/prysmaticlabs/prysm/config/features"
|
||||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,6 +51,9 @@ func BitwiseMerkleize(hasher HashFn, chunks [][]byte, count, limit uint64) ([32]
|
|||||||
if count > limit {
|
if count > limit {
|
||||||
return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
|
return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
|
||||||
}
|
}
|
||||||
|
if features.Get().EnableVectorizedHTR {
|
||||||
|
return MerkleizeList(chunks, limit), nil
|
||||||
|
}
|
||||||
hashFn := NewHasherFunc(hasher)
|
hashFn := NewHasherFunc(hasher)
|
||||||
leafIndexer := func(i uint64) []byte {
|
leafIndexer := func(i uint64) []byte {
|
||||||
return chunks[i]
|
return chunks[i]
|
||||||
@@ -62,6 +66,9 @@ func BitwiseMerkleizeArrays(hasher HashFn, chunks [][32]byte, count, limit uint6
|
|||||||
if count > limit {
|
if count > limit {
|
||||||
return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
|
return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
|
||||||
}
|
}
|
||||||
|
if features.Get().EnableVectorizedHTR {
|
||||||
|
return MerkleizeVector(chunks, limit), nil
|
||||||
|
}
|
||||||
hashFn := NewHasherFunc(hasher)
|
hashFn := NewHasherFunc(hasher)
|
||||||
leafIndexer := func(i uint64) []byte {
|
leafIndexer := func(i uint64) []byte {
|
||||||
return chunks[i][:]
|
return chunks[i][:]
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package ssz
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/prysmaticlabs/prysm/container/trie"
|
"github.com/prysmaticlabs/prysm/container/trie"
|
||||||
|
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Merkleize.go is mostly a directly copy of the same filename from
|
// Merkleize.go is mostly a directly copy of the same filename from
|
||||||
@@ -196,3 +197,40 @@ func ConstructProof(hasher Hasher, count, limit uint64, leaf func(i uint64) []by
|
|||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MerkleizeVector uses our optimized routine to hash a list of 32-byte
|
||||||
|
// elements.
|
||||||
|
func MerkleizeVector(elements [][32]byte, length uint64) [32]byte {
|
||||||
|
depth := Depth(length)
|
||||||
|
// Return zerohash at depth
|
||||||
|
if len(elements) == 0 {
|
||||||
|
return trie.ZeroHashes[depth]
|
||||||
|
}
|
||||||
|
for i := 0; i < int(depth); i++ {
|
||||||
|
layerLen := len(elements)
|
||||||
|
oddNodeLength := layerLen%2 == 1
|
||||||
|
if oddNodeLength {
|
||||||
|
zerohash := trie.ZeroHashes[i]
|
||||||
|
elements = append(elements, zerohash)
|
||||||
|
}
|
||||||
|
outputLen := len(elements) / 2
|
||||||
|
htr.VectorizedSha256(elements, elements)
|
||||||
|
elements = elements[:outputLen]
|
||||||
|
}
|
||||||
|
return elements[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// MerkleizeList uses our optimized routine to hash a 2d-list of
|
||||||
|
// elements.
|
||||||
|
func MerkleizeList(elements [][]byte, length uint64) [32]byte {
|
||||||
|
depth := Depth(length)
|
||||||
|
// Return zerohash at depth
|
||||||
|
if len(elements) == 0 {
|
||||||
|
return trie.ZeroHashes[depth]
|
||||||
|
}
|
||||||
|
newElems := make([][32]byte, len(elements))
|
||||||
|
for i := range elements {
|
||||||
|
copy(newElems[i][:], elements[i])
|
||||||
|
}
|
||||||
|
return MerkleizeVector(newElems, length)
|
||||||
|
}
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -253,6 +253,7 @@ require (
|
|||||||
github.com/holiman/uint256 v1.2.0
|
github.com/holiman/uint256 v1.2.0
|
||||||
github.com/peterh/liner v1.2.0 // indirect
|
github.com/peterh/liner v1.2.0 // indirect
|
||||||
github.com/prometheus/tsdb v0.10.0 // indirect
|
github.com/prometheus/tsdb v0.10.0 // indirect
|
||||||
|
github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f
|
||||||
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 // indirect
|
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 // indirect
|
||||||
google.golang.org/api v0.34.0 // indirect
|
google.golang.org/api v0.34.0 // indirect
|
||||||
google.golang.org/appengine v1.6.7 // indirect
|
google.golang.org/appengine v1.6.7 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -1123,6 +1123,8 @@ github.com/prysmaticlabs/fastssz v0.0.0-20220110145812-fafb696cae88/go.mod h1:AS
|
|||||||
github.com/prysmaticlabs/go-bitfield v0.0.0-20210108222456-8e92c3709aa0/go.mod h1:hCwmef+4qXWjv0jLDbQdWnL0Ol7cS7/lCSS26WR+u6s=
|
github.com/prysmaticlabs/go-bitfield v0.0.0-20210108222456-8e92c3709aa0/go.mod h1:hCwmef+4qXWjv0jLDbQdWnL0Ol7cS7/lCSS26WR+u6s=
|
||||||
github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7 h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw=
|
github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7 h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw=
|
||||||
github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7/go.mod h1:wmuf/mdK4VMD+jA9ThwcUKjg3a2XWM9cVfFYjDyY4j4=
|
github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7/go.mod h1:wmuf/mdK4VMD+jA9ThwcUKjg3a2XWM9cVfFYjDyY4j4=
|
||||||
|
github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f h1:qVXMFNUGWbPpP7986YPsmL84A54oY7b/N+SOIBS3i5w=
|
||||||
|
github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f/go.mod h1:4pWaT30XoEx1j8KNJf3TV+E3mQkaufn7mf+jRNb/Fuk=
|
||||||
github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1 h1:xcu59yYL6AWWTl6jtejBfE0y8uF35fArCBeZjRlvJss=
|
github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1 h1:xcu59yYL6AWWTl6jtejBfE0y8uF35fArCBeZjRlvJss=
|
||||||
github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1/go.mod h1:IOyTYjcIO0rkmnGBfJTL0NJ11exy/Tc2QEuv7hCXp24=
|
github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1/go.mod h1:IOyTYjcIO0rkmnGBfJTL0NJ11exy/Tc2QEuv7hCXp24=
|
||||||
github.com/prysmaticlabs/prombbolt v0.0.0-20210126082820-9b7adba6db7c h1:9PHRCuO/VN0s9k+RmLykho7AjDxblNYI5bYKed16NPU=
|
github.com/prysmaticlabs/prombbolt v0.0.0-20210126082820-9b7adba6db7c h1:9PHRCuO/VN0s9k+RmLykho7AjDxblNYI5bYKed16NPU=
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ go_test(
|
|||||||
deps = [
|
deps = [
|
||||||
"//config/params:go_default_library",
|
"//config/params:go_default_library",
|
||||||
"//crypto/bls:go_default_library",
|
"//crypto/bls:go_default_library",
|
||||||
"//encoding/ssz:go_default_library",
|
"//encoding/ssz/equality:go_default_library",
|
||||||
"//proto/prysm/v1alpha1:go_default_library",
|
"//proto/prysm/v1alpha1:go_default_library",
|
||||||
"//proto/prysm/v1alpha1/attestation/aggregation:go_default_library",
|
"//proto/prysm/v1alpha1/attestation/aggregation:go_default_library",
|
||||||
"//proto/prysm/v1alpha1/attestation/aggregation/testing:go_default_library",
|
"//proto/prysm/v1alpha1/attestation/aggregation/testing:go_default_library",
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/prysmaticlabs/go-bitfield"
|
"github.com/prysmaticlabs/go-bitfield"
|
||||||
"github.com/prysmaticlabs/prysm/config/params"
|
"github.com/prysmaticlabs/prysm/config/params"
|
||||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
|
||||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||||
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation"
|
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation"
|
||||||
aggtesting "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation/testing"
|
aggtesting "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation/testing"
|
||||||
@@ -48,7 +48,7 @@ func TestAggregateAttestations_AggregatePair(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
got, err := AggregatePair(tt.a1, tt.a2)
|
got, err := AggregatePair(tt.a1, tt.a2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, true, ssz.DeepEqual(got, tt.want))
|
require.Equal(t, true, equality.DeepEqual(got, tt.want))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ go_library(
|
|||||||
importpath = "github.com/prysmaticlabs/prysm/testing/assertions",
|
importpath = "github.com/prysmaticlabs/prysm/testing/assertions",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//encoding/ssz:go_default_library",
|
"//encoding/ssz/equality:go_default_library",
|
||||||
"@com_github_d4l3k_messagediff//:go_default_library",
|
"@com_github_d4l3k_messagediff//:go_default_library",
|
||||||
"@com_github_sirupsen_logrus//hooks/test:go_default_library",
|
"@com_github_sirupsen_logrus//hooks/test:go_default_library",
|
||||||
"@org_golang_google_protobuf//proto:go_default_library",
|
"@org_golang_google_protobuf//proto:go_default_library",
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/d4l3k/messagediff"
|
"github.com/d4l3k/messagediff"
|
||||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
|
||||||
"github.com/sirupsen/logrus/hooks/test"
|
"github.com/sirupsen/logrus/hooks/test"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
)
|
)
|
||||||
@@ -61,7 +61,7 @@ func DeepNotEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg
|
|||||||
|
|
||||||
// DeepSSZEqual compares values using ssz.DeepEqual.
|
// DeepSSZEqual compares values using ssz.DeepEqual.
|
||||||
func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
|
func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
|
||||||
if !ssz.DeepEqual(expected, actual) {
|
if !equality.DeepEqual(expected, actual) {
|
||||||
errMsg := parseMsg("Values are not equal", msg...)
|
errMsg := parseMsg("Values are not equal", msg...)
|
||||||
_, file, line, _ := runtime.Caller(2)
|
_, file, line, _ := runtime.Caller(2)
|
||||||
diff, _ := messagediff.PrettyDiff(expected, actual)
|
diff, _ := messagediff.PrettyDiff(expected, actual)
|
||||||
@@ -71,7 +71,7 @@ func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg
|
|||||||
|
|
||||||
// DeepNotSSZEqual compares values using ssz.DeepEqual.
|
// DeepNotSSZEqual compares values using ssz.DeepEqual.
|
||||||
func DeepNotSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
|
func DeepNotSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
|
||||||
if ssz.DeepEqual(expected, actual) {
|
if equality.DeepEqual(expected, actual) {
|
||||||
errMsg := parseMsg("Values are equal", msg...)
|
errMsg := parseMsg("Values are equal", msg...)
|
||||||
_, file, line, _ := runtime.Caller(2)
|
_, file, line, _ := runtime.Caller(2)
|
||||||
loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual)
|
loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ go_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//beacon-chain/core/transition:go_default_library",
|
"//beacon-chain/core/transition:go_default_library",
|
||||||
"//beacon-chain/state/v1:go_default_library",
|
"//beacon-chain/state/v1:go_default_library",
|
||||||
"//encoding/ssz:go_default_library",
|
"//encoding/ssz/equality:go_default_library",
|
||||||
"//proto/prysm/v1alpha1:go_default_library",
|
"//proto/prysm/v1alpha1:go_default_library",
|
||||||
"//proto/prysm/v1alpha1/wrapper:go_default_library",
|
"//proto/prysm/v1alpha1/wrapper:go_default_library",
|
||||||
"//runtime/version:go_default_library",
|
"//runtime/version:go_default_library",
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
"github.com/kr/pretty"
|
"github.com/kr/pretty"
|
||||||
"github.com/prysmaticlabs/prysm/beacon-chain/core/transition"
|
"github.com/prysmaticlabs/prysm/beacon-chain/core/transition"
|
||||||
v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1"
|
v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1"
|
||||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
|
||||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||||
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper"
|
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper"
|
||||||
"github.com/prysmaticlabs/prysm/runtime/version"
|
"github.com/prysmaticlabs/prysm/runtime/version"
|
||||||
@@ -191,7 +191,7 @@ func main() {
|
|||||||
if err := dataFetcher(expectedPostStatePath, expectedState); err != nil {
|
if err := dataFetcher(expectedPostStatePath, expectedState); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
if !ssz.DeepEqual(expectedState, postState.InnerStateUnsafe()) {
|
if !equality.DeepEqual(expectedState, postState.InnerStateUnsafe()) {
|
||||||
diff, _ := messagediff.PrettyDiff(expectedState, postState.InnerStateUnsafe())
|
diff, _ := messagediff.PrettyDiff(expectedState, postState.InnerStateUnsafe())
|
||||||
log.Errorf("Derived state differs from provided post state: %s", diff)
|
log.Errorf("Derived state differs from provided post state: %s", diff)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user