Integration of Vectorized Sha256 In Prysm (#10166)

* add changes

* fix for vectorize

* fix bug

* add new bench

* use new algorithms

* add latest updates

* save progress

* hack even more

* add more changes

* change library

* go mod

* fix deps

* fix dumb bug

* add flag and remove redundant code

* clean up better

* remove those ones

* clean up benches

* clean up benches

* cleanup

* gaz

* revert change

* potuz's review

* potuz's review

* potuz's review

* gaz

* potuz's review

* remove cyclical import

* revert ide changes

* potuz's review

* return
This commit is contained in:
Nishant Das
2022-02-28 21:56:12 +08:00
committed by GitHub
parent 12ba8f3645
commit 339540274b
30 changed files with 491 additions and 95 deletions

View File

@@ -258,3 +258,21 @@ func TestBeaconState_AppendValidator_DoesntMutateCopy(t *testing.T) {
_, ok := st1.ValidatorIndexByPubkey(bytesutil.ToBytes48(val.PublicKey)) _, ok := st1.ValidatorIndexByPubkey(bytesutil.ToBytes48(val.PublicKey))
assert.Equal(t, false, ok, "Expected no validator index to be present in st1 for the newly inserted pubkey") assert.Equal(t, false, ok, "Expected no validator index to be present in st1 for the newly inserted pubkey")
} }
func BenchmarkBeaconState(b *testing.B) {
testState, _ := util.DeterministicGenesisState(b, 16000)
pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe())
require.NoError(b, err)
b.Run("Vectorized SHA256", func(b *testing.B) {
st, err := v1.InitializeFromProtoUnsafe(pbState)
require.NoError(b, err)
_, err = st.HashTreeRoot(context.Background())
assert.NoError(b, err)
})
b.Run("Current SHA256", func(b *testing.B) {
_, err := pbState.HashTreeRoot()
require.NoError(b, err)
})
}

View File

@@ -32,10 +32,12 @@ go_library(
], ],
deps = [ deps = [
"//beacon-chain/core/transition/stateutils:go_default_library", "//beacon-chain/core/transition/stateutils:go_default_library",
"//config/features:go_default_library",
"//config/fieldparams:go_default_library", "//config/fieldparams:go_default_library",
"//config/params:go_default_library", "//config/params:go_default_library",
"//container/trie:go_default_library", "//container/trie:go_default_library",
"//crypto/hash:go_default_library", "//crypto/hash:go_default_library",
"//crypto/hash/htr:go_default_library",
"//encoding/bytesutil:go_default_library", "//encoding/bytesutil:go_default_library",
"//encoding/ssz:go_default_library", "//encoding/ssz:go_default_library",
"//math:go_default_library", "//math:go_default_library",
@@ -51,6 +53,7 @@ go_test(
srcs = [ srcs = [
"benchmark_test.go", "benchmark_test.go",
"field_root_test.go", "field_root_test.go",
"field_root_validator_test.go",
"reference_bench_test.go", "reference_bench_test.go",
"state_root_test.go", "state_root_test.go",
"trie_helpers_test.go", "trie_helpers_test.go",
@@ -58,11 +61,14 @@ go_test(
], ],
embed = [":go_default_library"], embed = [":go_default_library"],
deps = [ deps = [
"//beacon-chain/state:go_default_library",
"//config/features:go_default_library",
"//config/fieldparams:go_default_library", "//config/fieldparams:go_default_library",
"//config/params:go_default_library", "//config/params:go_default_library",
"//crypto/hash:go_default_library", "//crypto/hash:go_default_library",
"//encoding/bytesutil:go_default_library", "//encoding/bytesutil:go_default_library",
"//encoding/ssz:go_default_library", "//encoding/ssz:go_default_library",
"//math:go_default_library",
"//proto/prysm/v1alpha1:go_default_library", "//proto/prysm/v1alpha1:go_default_library",
"//runtime/interop:go_default_library", "//runtime/interop:go_default_library",
"//testing/assert:go_default_library", "//testing/assert:go_default_library",

View File

@@ -5,12 +5,25 @@ import (
"encoding/binary" "encoding/binary"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/config/features"
fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams" fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams"
"github.com/prysmaticlabs/prysm/crypto/hash" "github.com/prysmaticlabs/prysm/crypto/hash"
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
"github.com/prysmaticlabs/prysm/encoding/ssz" "github.com/prysmaticlabs/prysm/encoding/ssz"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
) )
const (
// number of field roots for the validator object.
validatorFieldRoots = 8
// Depth of tree representation of an individual
// validator.
// NumOfRoots = 2 ^ (TreeDepth)
// 8 = 2 ^ 3
validatorTreeDepth = 3
)
// ValidatorRegistryRoot computes the HashTreeRoot Merkleization of // ValidatorRegistryRoot computes the HashTreeRoot Merkleization of
// a list of validator structs according to the Ethereum // a list of validator structs according to the Ethereum
// Simple Serialize specification. // Simple Serialize specification.
@@ -19,14 +32,20 @@ func ValidatorRegistryRoot(vals []*ethpb.Validator) ([32]byte, error) {
} }
func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) { func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) {
roots := make([][32]byte, len(validators))
hasher := hash.CustomSHA256Hasher() hasher := hash.CustomSHA256Hasher()
for i := 0; i < len(validators); i++ {
val, err := validatorRoot(hasher, validators[i]) var err error
var roots [][32]byte
if features.Get().EnableVectorizedHTR {
roots, err = optimizedValidatorRoots(validators)
if err != nil { if err != nil {
return [32]byte{}, errors.Wrap(err, "could not compute validators merkleization") return [32]byte{}, err
}
} else {
roots, err = validatorRoots(hasher, validators)
if err != nil {
return [32]byte{}, err
} }
roots[i] = val
} }
validatorsRootsRoot, err := ssz.BitwiseMerkleizeArrays(hasher, roots, uint64(len(roots)), fieldparams.ValidatorRegistryLimit) validatorsRootsRoot, err := ssz.BitwiseMerkleizeArrays(hasher, roots, uint64(len(roots)), fieldparams.ValidatorRegistryLimit)
@@ -45,6 +64,42 @@ func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) {
return res, nil return res, nil
} }
func validatorRoots(hasher func([]byte) [32]byte, validators []*ethpb.Validator) ([][32]byte, error) {
roots := make([][32]byte, len(validators))
for i := 0; i < len(validators); i++ {
val, err := validatorRoot(hasher, validators[i])
if err != nil {
return [][32]byte{}, errors.Wrap(err, "could not compute validators merkleization")
}
roots[i] = val
}
return roots, nil
}
func optimizedValidatorRoots(validators []*ethpb.Validator) ([][32]byte, error) {
roots := make([][32]byte, 0, len(validators)*validatorFieldRoots)
hasher := hash.CustomSHA256Hasher()
for i := 0; i < len(validators); i++ {
fRoots, err := ValidatorFieldRoots(hasher, validators[i])
if err != nil {
return [][32]byte{}, errors.Wrap(err, "could not compute validators merkleization")
}
roots = append(roots, fRoots...)
}
// A validator's tree can represented with a depth of 3. As log2(8) = 3
// Using this property we can lay out all the individual fields of a
// validator and hash them in single level using our vectorized routine.
for i := 0; i < validatorTreeDepth; i++ {
// Overwrite input lists as we are hashing by level
// and only need the highest level to proceed.
outputLen := len(roots) / 2
htr.VectorizedSha256(roots, roots)
roots = roots[:outputLen]
}
return roots, nil
}
func validatorRoot(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) { func validatorRoot(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) {
if validator == nil { if validator == nil {
return [32]byte{}, errors.New("nil validator") return [32]byte{}, errors.New("nil validator")

View File

@@ -0,0 +1,32 @@
package stateutil
import (
"reflect"
"strings"
"testing"
mathutil "github.com/prysmaticlabs/prysm/math"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/testing/assert"
)
func TestValidatorConstants(t *testing.T) {
v := &ethpb.Validator{}
refV := reflect.ValueOf(v).Elem()
numFields := refV.NumField()
numOfValFields := 0
for i := 0; i < numFields; i++ {
if strings.Contains(refV.Type().Field(i).Name, "state") ||
strings.Contains(refV.Type().Field(i).Name, "sizeCache") ||
strings.Contains(refV.Type().Field(i).Name, "unknownFields") {
continue
}
numOfValFields++
}
assert.Equal(t, validatorFieldRoots, numOfValFields)
assert.Equal(t, uint64(validatorFieldRoots), mathutil.PowerOf2(validatorTreeDepth))
_, err := ValidatorRegistryRoot([]*ethpb.Validator{v})
assert.NoError(t, err)
}

View File

@@ -5,8 +5,10 @@ import (
"encoding/binary" "encoding/binary"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/container/trie" "github.com/prysmaticlabs/prysm/container/trie"
"github.com/prysmaticlabs/prysm/crypto/hash" "github.com/prysmaticlabs/prysm/crypto/hash"
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
"github.com/prysmaticlabs/prysm/encoding/ssz" "github.com/prysmaticlabs/prysm/encoding/ssz"
"github.com/prysmaticlabs/prysm/math" "github.com/prysmaticlabs/prysm/math"
) )
@@ -61,25 +63,43 @@ func ReturnTrieLayerVariable(elements [][32]byte, length uint64) [][]*[32]byte {
layers[0] = transformedLeaves layers[0] = transformedLeaves
buffer := bytes.NewBuffer([]byte{}) buffer := bytes.NewBuffer([]byte{})
buffer.Grow(64) buffer.Grow(64)
for i := 0; i < int(depth); i++ { for i := 0; i < int(depth); i++ {
oddNodeLength := len(layers[i])%2 == 1 layerLen := len(layers[i])
if oddNodeLength { oddNodeLength := layerLen%2 == 1
zerohash := trie.ZeroHashes[i] if features.Get().EnableVectorizedHTR {
layers[i] = append(layers[i], &zerohash) if oddNodeLength {
zerohash := trie.ZeroHashes[i]
elements = append(elements, zerohash)
layerLen++
}
layers[i+1] = make([]*[32]byte, layerLen/2)
newElems := make([][32]byte, layerLen/2)
htr.VectorizedSha256(elements, newElems)
elements = newElems
for j := range elements {
layers[i+1][j] = &elements[j]
}
} else {
if oddNodeLength {
zerohash := trie.ZeroHashes[i]
layers[i] = append(layers[i], &zerohash)
}
updatedValues := make([]*[32]byte, 0, len(layers[i])/2)
for j := 0; j < len(layers[i]); j += 2 {
buffer.Write(layers[i][j][:])
buffer.Write(layers[i][j+1][:])
concat := hasher(buffer.Bytes())
updatedValues = append(updatedValues, &concat)
buffer.Reset()
}
// remove zerohash node from tree
if oddNodeLength {
layers[i] = layers[i][:len(layers[i])-1]
}
layers[i+1] = updatedValues
} }
updatedValues := make([]*[32]byte, 0, len(layers[i])/2)
for j := 0; j < len(layers[i]); j += 2 {
buffer.Write(layers[i][j][:])
buffer.Write(layers[i][j+1][:])
concat := hasher(buffer.Bytes())
updatedValues = append(updatedValues, &concat)
buffer.Reset()
}
// remove zerohash node from tree
if oddNodeLength {
layers[i] = layers[i][:len(layers[i])-1]
}
layers[i+1] = updatedValues
} }
return layers return layers
} }
@@ -277,18 +297,24 @@ func MerkleizeTrieLeaves(layers [][][32]byte, hashLayer [][32]byte,
chunkBuffer := bytes.NewBuffer([]byte{}) chunkBuffer := bytes.NewBuffer([]byte{})
chunkBuffer.Grow(64) chunkBuffer.Grow(64)
for len(hashLayer) > 1 && i < len(layers) { for len(hashLayer) > 1 && i < len(layers) {
layer := make([][32]byte, len(hashLayer)/2)
if !math.IsPowerOf2(uint64(len(hashLayer))) { if !math.IsPowerOf2(uint64(len(hashLayer))) {
return nil, nil, errors.Errorf("hash layer is a non power of 2: %d", len(hashLayer)) return nil, nil, errors.Errorf("hash layer is a non power of 2: %d", len(hashLayer))
} }
for j := 0; j < len(hashLayer); j += 2 { if features.Get().EnableVectorizedHTR {
chunkBuffer.Write(hashLayer[j][:]) newLayer := make([][32]byte, len(hashLayer)/2)
chunkBuffer.Write(hashLayer[j+1][:]) htr.VectorizedSha256(hashLayer, newLayer)
hashedChunk := hasher(chunkBuffer.Bytes()) hashLayer = newLayer
layer[j/2] = hashedChunk } else {
chunkBuffer.Reset() layer := make([][32]byte, len(hashLayer)/2)
for j := 0; j < len(hashLayer); j += 2 {
chunkBuffer.Write(hashLayer[j][:])
chunkBuffer.Write(hashLayer[j+1][:])
hashedChunk := hasher(chunkBuffer.Bytes())
layer[j/2] = hashedChunk
chunkBuffer.Reset()
}
hashLayer = layer
} }
hashLayer = layer
layers[i] = hashLayer layers[i] = hashLayer
i++ i++
} }

View File

@@ -4,7 +4,9 @@ import (
"testing" "testing"
types "github.com/prysmaticlabs/eth2-types" types "github.com/prysmaticlabs/eth2-types"
"github.com/prysmaticlabs/prysm/beacon-chain/state"
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/config/params" "github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/crypto/hash" "github.com/prysmaticlabs/prysm/crypto/hash"
"github.com/prysmaticlabs/prysm/encoding/bytesutil" "github.com/prysmaticlabs/prysm/encoding/bytesutil"
@@ -18,15 +20,56 @@ func TestReturnTrieLayer_OK(t *testing.T) {
newState, _ := util.DeterministicGenesisState(t, 32) newState, _ := util.DeterministicGenesisState(t, 32)
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot)) root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
require.NoError(t, err) require.NoError(t, err)
blockRts := newState.BlockRoots() roots := retrieveBlockRoots(newState)
roots := make([][32]byte, 0, len(blockRts))
for _, rt := range blockRts {
roots = append(roots, bytesutil.ToBytes32(rt))
}
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots))) layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
assert.NoError(t, err) assert.NoError(t, err)
newRoot := *layers[len(layers)-1][0] newRoot := *layers[len(layers)-1][0]
assert.Equal(t, root, newRoot) assert.Equal(t, root, newRoot)
flags := &features.Flags{}
flags.EnableVectorizedHTR = true
reset := features.InitWithReset(flags)
defer reset()
layers, err = stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
assert.NoError(t, err)
lastRoot := *layers[len(layers)-1][0]
assert.Equal(t, root, lastRoot)
}
func BenchmarkReturnTrieLayer_NormalAlgorithm(b *testing.B) {
newState, _ := util.DeterministicGenesisState(b, 32)
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
require.NoError(b, err)
roots := retrieveBlockRoots(newState)
b.ResetTimer()
for i := 0; i < b.N; i++ {
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
assert.NoError(b, err)
newRoot := *layers[len(layers)-1][0]
assert.Equal(b, root, newRoot)
}
}
func BenchmarkReturnTrieLayer_VectorizedAlgorithm(b *testing.B) {
flags := &features.Flags{}
flags.EnableVectorizedHTR = true
reset := features.InitWithReset(flags)
defer reset()
newState, _ := util.DeterministicGenesisState(b, 32)
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
require.NoError(b, err)
roots := retrieveBlockRoots(newState)
b.ResetTimer()
for i := 0; i < b.N; i++ {
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
assert.NoError(b, err)
newRoot := *layers[len(layers)-1][0]
assert.Equal(b, root, newRoot)
}
} }
func TestReturnTrieLayerVariable_OK(t *testing.T) { func TestReturnTrieLayerVariable_OK(t *testing.T) {
@@ -46,15 +89,73 @@ func TestReturnTrieLayerVariable_OK(t *testing.T) {
newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators))) newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators)))
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, root, newRoot) assert.Equal(t, root, newRoot)
flags := &features.Flags{}
flags.EnableVectorizedHTR = true
reset := features.InitWithReset(flags)
defer reset()
layers = stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit)
lastRoot := *layers[len(layers)-1][0]
lastRoot, err = stateutil.AddInMixin(lastRoot, uint64(len(validators)))
require.NoError(t, err)
assert.Equal(t, root, lastRoot)
}
func BenchmarkReturnTrieLayerVariable_NormalAlgorithm(b *testing.B) {
newState, _ := util.DeterministicGenesisState(b, 16000)
root, err := stateutil.ValidatorRegistryRoot(newState.Validators())
require.NoError(b, err)
hasher := hash.CustomSHA256Hasher()
validators := newState.Validators()
roots := make([][32]byte, 0, len(validators))
for _, val := range validators {
rt, err := stateutil.ValidatorRootWithHasher(hasher, val)
require.NoError(b, err)
roots = append(roots, rt)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
layers := stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit)
newRoot := *layers[len(layers)-1][0]
newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators)))
require.NoError(b, err)
assert.Equal(b, root, newRoot)
}
}
func BenchmarkReturnTrieLayerVariable_VectorizedAlgorithm(b *testing.B) {
flags := &features.Flags{}
flags.EnableVectorizedHTR = true
reset := features.InitWithReset(flags)
defer reset()
newState, _ := util.DeterministicGenesisState(b, 16000)
root, err := stateutil.ValidatorRegistryRoot(newState.Validators())
require.NoError(b, err)
hasher := hash.CustomSHA256Hasher()
validators := newState.Validators()
roots := make([][32]byte, 0, len(validators))
for _, val := range validators {
rt, err := stateutil.ValidatorRootWithHasher(hasher, val)
require.NoError(b, err)
roots = append(roots, rt)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
layers := stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit)
newRoot := *layers[len(layers)-1][0]
newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators)))
require.NoError(b, err)
assert.Equal(b, root, newRoot)
}
} }
func TestRecomputeFromLayer_FixedSizedArray(t *testing.T) { func TestRecomputeFromLayer_FixedSizedArray(t *testing.T) {
newState, _ := util.DeterministicGenesisState(t, 32) newState, _ := util.DeterministicGenesisState(t, 32)
blockRts := newState.BlockRoots() roots := retrieveBlockRoots(newState)
roots := make([][32]byte, 0, len(blockRts))
for _, rt := range blockRts {
roots = append(roots, bytesutil.ToBytes32(rt))
}
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots))) layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
require.NoError(t, err) require.NoError(t, err)
@@ -120,3 +221,12 @@ func TestMerkleizeTrieLeaves_BadHashLayer(t *testing.T) {
}) })
assert.ErrorContains(t, "hash layer is a non power of 2", err) assert.ErrorContains(t, "hash layer is a non power of 2", err)
} }
func retrieveBlockRoots(b state.BeaconState) [][32]byte {
blockRts := b.BlockRoots()
roots := make([][32]byte, 0, len(blockRts))
for _, rt := range blockRts {
roots = append(roots, bytesutil.ToBytes32(rt))
}
return roots
}

View File

@@ -4,8 +4,10 @@ import (
"encoding/binary" "encoding/binary"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/config/features"
fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams" fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams"
"github.com/prysmaticlabs/prysm/crypto/hash" "github.com/prysmaticlabs/prysm/crypto/hash"
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
"github.com/prysmaticlabs/prysm/encoding/bytesutil" "github.com/prysmaticlabs/prysm/encoding/bytesutil"
"github.com/prysmaticlabs/prysm/encoding/ssz" "github.com/prysmaticlabs/prysm/encoding/ssz"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
@@ -14,6 +16,16 @@ import (
// ValidatorRootWithHasher describes a method from which the hash tree root // ValidatorRootWithHasher describes a method from which the hash tree root
// of a validator is returned. // of a validator is returned.
func ValidatorRootWithHasher(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) { func ValidatorRootWithHasher(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) {
fieldRoots, err := ValidatorFieldRoots(hasher, validator)
if err != nil {
return [32]byte{}, err
}
return ssz.BitwiseMerkleizeArrays(hasher, fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots)))
}
// ValidatorFieldRoots describes a method from which the hash tree root
// of a validator is returned.
func ValidatorFieldRoots(hasher ssz.HashFn, validator *ethpb.Validator) ([][32]byte, error) {
var fieldRoots [][32]byte var fieldRoots [][32]byte
if validator != nil { if validator != nil {
pubkey := bytesutil.ToBytes48(validator.PublicKey) pubkey := bytesutil.ToBytes48(validator.PublicKey)
@@ -40,18 +52,25 @@ func ValidatorRootWithHasher(hasher ssz.HashFn, validator *ethpb.Validator) ([32
binary.LittleEndian.PutUint64(withdrawalBuf[:8], uint64(validator.WithdrawableEpoch)) binary.LittleEndian.PutUint64(withdrawalBuf[:8], uint64(validator.WithdrawableEpoch))
// Public key. // Public key.
pubKeyChunks, err := ssz.Pack([][]byte{pubkey[:]}) pubKeyChunks, err := ssz.PackByChunk([][]byte{pubkey[:]})
if err != nil { if err != nil {
return [32]byte{}, err return [][32]byte{}, err
} }
pubKeyRoot, err := ssz.BitwiseMerkleize(hasher, pubKeyChunks, uint64(len(pubKeyChunks)), uint64(len(pubKeyChunks))) var pubKeyRoot [32]byte
if err != nil { if features.Get().EnableVectorizedHTR {
return [32]byte{}, err outputChunk := make([][32]byte, 1)
htr.VectorizedSha256(pubKeyChunks, outputChunk)
pubKeyRoot = outputChunk[0]
} else {
pubKeyRoot, err = ssz.BitwiseMerkleizeArrays(hasher, pubKeyChunks, uint64(len(pubKeyChunks)), uint64(len(pubKeyChunks)))
if err != nil {
return [][32]byte{}, err
}
} }
fieldRoots = [][32]byte{pubKeyRoot, withdrawCreds, effectiveBalanceBuf, slashBuf, activationEligibilityBuf, fieldRoots = [][32]byte{pubKeyRoot, withdrawCreds, effectiveBalanceBuf, slashBuf, activationEligibilityBuf,
activationBuf, exitBuf, withdrawalBuf} activationBuf, exitBuf, withdrawalBuf}
} }
return ssz.BitwiseMerkleizeArrays(hasher, fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots))) return fieldRoots, nil
} }
// Uint64ListRootWithRegistryLimit computes the HashTreeRoot Merkleization of // Uint64ListRootWithRegistryLimit computes the HashTreeRoot Merkleization of

View File

@@ -177,6 +177,24 @@ func TestBeaconState_HashTreeRoot(t *testing.T) {
} }
} }
func BenchmarkBeaconState(b *testing.B) {
testState, _ := util.DeterministicGenesisState(b, 16000)
pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe())
require.NoError(b, err)
b.Run("Vectorized SHA256", func(b *testing.B) {
st, err := v1.InitializeFromProtoUnsafe(pbState)
require.NoError(b, err)
_, err = st.HashTreeRoot(context.Background())
assert.NoError(b, err)
})
b.Run("Current SHA256", func(b *testing.B) {
_, err := pbState.HashTreeRoot()
require.NoError(b, err)
})
}
func TestBeaconState_HashTreeRoot_FieldTrie(t *testing.T) { func TestBeaconState_HashTreeRoot_FieldTrie(t *testing.T) {
testState, _ := util.DeterministicGenesisState(t, 64) testState, _ := util.DeterministicGenesisState(t, 64)

View File

@@ -86,7 +86,7 @@ go_library(
"//crypto/bls:go_default_library", "//crypto/bls:go_default_library",
"//crypto/rand:go_default_library", "//crypto/rand:go_default_library",
"//encoding/bytesutil:go_default_library", "//encoding/bytesutil:go_default_library",
"//encoding/ssz:go_default_library", "//encoding/ssz/equality:go_default_library",
"//monitoring/tracing:go_default_library", "//monitoring/tracing:go_default_library",
"//network/forks:go_default_library", "//network/forks:go_default_library",
"//proto/prysm/v1alpha1:go_default_library", "//proto/prysm/v1alpha1:go_default_library",
@@ -195,7 +195,7 @@ go_test(
"//crypto/bls:go_default_library", "//crypto/bls:go_default_library",
"//crypto/rand:go_default_library", "//crypto/rand:go_default_library",
"//encoding/bytesutil:go_default_library", "//encoding/bytesutil:go_default_library",
"//encoding/ssz:go_default_library", "//encoding/ssz/equality:go_default_library",
"//network/forks:go_default_library", "//network/forks:go_default_library",
"//proto/prysm/v1alpha1:go_default_library", "//proto/prysm/v1alpha1:go_default_library",
"//proto/prysm/v1alpha1/attestation:go_default_library", "//proto/prysm/v1alpha1/attestation:go_default_library",

View File

@@ -15,7 +15,7 @@ import (
"github.com/prysmaticlabs/prysm/config/params" "github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/crypto/rand" "github.com/prysmaticlabs/prysm/crypto/rand"
"github.com/prysmaticlabs/prysm/encoding/bytesutil" "github.com/prysmaticlabs/prysm/encoding/bytesutil"
"github.com/prysmaticlabs/prysm/encoding/ssz" "github.com/prysmaticlabs/prysm/encoding/ssz/equality"
"github.com/prysmaticlabs/prysm/monitoring/tracing" "github.com/prysmaticlabs/prysm/monitoring/tracing"
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/block" "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/block"
"github.com/prysmaticlabs/prysm/time/slots" "github.com/prysmaticlabs/prysm/time/slots"
@@ -328,7 +328,7 @@ func (s *Service) deleteBlockFromPendingQueue(slot types.Slot, b block.SignedBea
newBlks := make([]block.SignedBeaconBlock, 0, len(blks)) newBlks := make([]block.SignedBeaconBlock, 0, len(blks))
for _, blk := range blks { for _, blk := range blks {
if ssz.DeepEqual(blk.Proto(), b.Proto()) { if equality.DeepEqual(blk.Proto(), b.Proto()) {
continue continue
} }
newBlks = append(newBlks, blk) newBlks = append(newBlks, blk)

View File

@@ -17,7 +17,7 @@ import (
"github.com/prysmaticlabs/prysm/beacon-chain/p2p" "github.com/prysmaticlabs/prysm/beacon-chain/p2p"
p2ptest "github.com/prysmaticlabs/prysm/beacon-chain/p2p/testing" p2ptest "github.com/prysmaticlabs/prysm/beacon-chain/p2p/testing"
"github.com/prysmaticlabs/prysm/config/params" "github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/encoding/ssz" "github.com/prysmaticlabs/prysm/encoding/ssz/equality"
pb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" pb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/metadata" "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/metadata"
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper" "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper"
@@ -125,7 +125,7 @@ func TestMetadataRPCHandler_SendsMetadata(t *testing.T) {
metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID()) metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID())
assert.NoError(t, err) assert.NoError(t, err)
if !ssz.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) { if !equality.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
t.Fatalf("MetadataV0 unequal, received %v but wanted %v", metadata, p2.LocalMetadata) t.Fatalf("MetadataV0 unequal, received %v but wanted %v", metadata, p2.LocalMetadata)
} }
@@ -213,7 +213,7 @@ func TestMetadataRPCHandler_SendsMetadataAltair(t *testing.T) {
metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID()) metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID())
assert.NoError(t, err) assert.NoError(t, err)
if !ssz.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) { if !equality.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
t.Fatalf("MetadataV1 unequal, received %v but wanted %v", metadata, p2.LocalMetadata) t.Fatalf("MetadataV1 unequal, received %v but wanted %v", metadata, p2.LocalMetadata)
} }

View File

@@ -74,7 +74,8 @@ type Flags struct {
CorrectlyInsertOrphanedAtts bool CorrectlyInsertOrphanedAtts bool
CorrectlyPruneCanonicalAtts bool CorrectlyPruneCanonicalAtts bool
EnableNativeState bool // EnableNativeState defines whether the beacon state will be represented as a pure Go struct or a Go struct that wraps a proto struct. EnableNativeState bool // EnableNativeState defines whether the beacon state will be represented as a pure Go struct or a Go struct that wraps a proto struct.
EnableVectorizedHTR bool // EnableVectorizedHTR specifies whether the beacon state will use the optimized sha256 routines.
// KeystoreImportDebounceInterval specifies the time duration the validator waits to reload new keys if they have // KeystoreImportDebounceInterval specifies the time duration the validator waits to reload new keys if they have
// changed on disk. This feature is for advanced use cases only. // changed on disk. This feature is for advanced use cases only.
@@ -222,6 +223,10 @@ func ConfigureBeaconChain(ctx *cli.Context) {
logEnabled(enableNativeState) logEnabled(enableNativeState)
cfg.EnableNativeState = true cfg.EnableNativeState = true
} }
if ctx.Bool(enableVecHTR.Name) {
logEnabled(enableVecHTR)
cfg.EnableVectorizedHTR = true
}
Init(cfg) Init(cfg)
} }

View File

@@ -134,11 +134,16 @@ var (
Name: "enable-native-state", Name: "enable-native-state",
Usage: "Enables representing the beacon state as a pure Go struct.", Usage: "Enables representing the beacon state as a pure Go struct.",
} }
enableVecHTR = &cli.BoolFlag{
Name: "enable-vectorized-htr",
Usage: "Enables new go sha256 library which utilizes optimized routines for merkle trees",
}
) )
// devModeFlags holds list of flags that are set when development mode is on. // devModeFlags holds list of flags that are set when development mode is on.
var devModeFlags = []cli.Flag{ var devModeFlags = []cli.Flag{
enablePeerScorer, enablePeerScorer,
enableVecHTR,
} }
// ValidatorFlags contains a list of all the feature flags that apply to the validator client. // ValidatorFlags contains a list of all the feature flags that apply to the validator client.
@@ -183,6 +188,7 @@ var BeaconChainFlags = append(deprecatedFlags, []cli.Flag{
disableBatchGossipVerification, disableBatchGossipVerification,
disableBalanceTrieComputation, disableBalanceTrieComputation,
enableNativeState, enableNativeState,
enableVecHTR,
}...) }...)
// E2EBeaconChainFlags contains a list of the beacon chain feature flags to be tested in E2E. // E2EBeaconChainFlags contains a list of the beacon chain feature flags to be tested in E2E.

View File

@@ -0,0 +1,9 @@
load("@prysm//tools/go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = ["hashtree.go"],
importpath = "github.com/prysmaticlabs/prysm/crypto/hash/htr",
visibility = ["//visibility:public"],
deps = ["@com_github_prysmaticlabs_gohashtree//:go_default_library"],
)

View File

@@ -0,0 +1,17 @@
package htr
import (
"github.com/prysmaticlabs/gohashtree"
)
// VectorizedSha256 takes a list of roots and hashes them using CPU
// specific vector instructions. Depending on host machine's specific
// hardware configuration, using this routine can lead to a significant
// performance improvement compared to the default method of hashing
// lists.
func VectorizedSha256(inputList [][32]byte, outputList [][32]byte) {
err := gohashtree.Hash(outputList, inputList)
if err != nil {
panic(err)
}
}

View File

@@ -2982,6 +2982,13 @@ def prysm_deps():
sum = "h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw=", sum = "h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw=",
version = "v0.0.0-20210809151128-385d8c5e3fb7", version = "v0.0.0-20210809151128-385d8c5e3fb7",
) )
go_repository(
name = "com_github_prysmaticlabs_gohashtree",
importpath = "github.com/prysmaticlabs/gohashtree",
sum = "h1:qVXMFNUGWbPpP7986YPsmL84A54oY7b/N+SOIBS3i5w=",
version = "v0.0.0-20220208111633-0606f58df32f",
)
go_repository( go_repository(
name = "com_github_prysmaticlabs_prombbolt", name = "com_github_prysmaticlabs_prombbolt",
importpath = "github.com/prysmaticlabs/prombbolt", importpath = "github.com/prysmaticlabs/prombbolt",

View File

@@ -3,7 +3,6 @@ load("@prysm//tools/go:def.bzl", "go_library", "go_test")
go_library( go_library(
name = "go_default_library", name = "go_default_library",
srcs = [ srcs = [
"deep_equal.go",
"hashers.go", "hashers.go",
"helpers.go", "helpers.go",
"htrutils.go", "htrutils.go",
@@ -12,16 +11,16 @@ go_library(
importpath = "github.com/prysmaticlabs/prysm/encoding/ssz", importpath = "github.com/prysmaticlabs/prysm/encoding/ssz",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//config/features:go_default_library",
"//config/fieldparams:go_default_library", "//config/fieldparams:go_default_library",
"//container/trie:go_default_library", "//container/trie:go_default_library",
"//crypto/hash:go_default_library", "//crypto/hash:go_default_library",
"//crypto/hash/htr:go_default_library",
"//encoding/bytesutil:go_default_library", "//encoding/bytesutil:go_default_library",
"//proto/prysm/v1alpha1:go_default_library", "//proto/prysm/v1alpha1:go_default_library",
"@com_github_minio_sha256_simd//:go_default_library", "@com_github_minio_sha256_simd//:go_default_library",
"@com_github_pkg_errors//:go_default_library", "@com_github_pkg_errors//:go_default_library",
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
"@com_github_prysmaticlabs_go_bitfield//:go_default_library", "@com_github_prysmaticlabs_go_bitfield//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
], ],
) )
@@ -29,7 +28,6 @@ go_test(
name = "go_default_test", name = "go_default_test",
size = "small", size = "small",
srcs = [ srcs = [
"deep_equal_test.go",
"hashers_test.go", "hashers_test.go",
"helpers_test.go", "helpers_test.go",
"htrutils_test.go", "htrutils_test.go",

View File

@@ -0,0 +1,22 @@
load("@prysm//tools/go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
srcs = ["deep_equal.go"],
importpath = "github.com/prysmaticlabs/prysm/encoding/ssz/equality",
visibility = ["//visibility:public"],
deps = [
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = ["deep_equal_test.go"],
deps = [
":go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
"//testing/assert:go_default_library",
],
)

View File

@@ -1,4 +1,4 @@
package ssz package equality
import ( import (
"reflect" "reflect"

View File

@@ -1,34 +1,34 @@
package ssz_test package equality_test
import ( import (
"testing" "testing"
"github.com/prysmaticlabs/prysm/encoding/ssz" "github.com/prysmaticlabs/prysm/encoding/ssz/equality"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/testing/assert" "github.com/prysmaticlabs/prysm/testing/assert"
) )
func TestDeepEqualBasicTypes(t *testing.T) { func TestDeepEqualBasicTypes(t *testing.T) {
assert.Equal(t, true, ssz.DeepEqual(true, true)) assert.Equal(t, true, equality.DeepEqual(true, true))
assert.Equal(t, false, ssz.DeepEqual(true, false)) assert.Equal(t, false, equality.DeepEqual(true, false))
assert.Equal(t, true, ssz.DeepEqual(byte(222), byte(222))) assert.Equal(t, true, equality.DeepEqual(byte(222), byte(222)))
assert.Equal(t, false, ssz.DeepEqual(byte(222), byte(111))) assert.Equal(t, false, equality.DeepEqual(byte(222), byte(111)))
assert.Equal(t, true, ssz.DeepEqual(uint64(1234567890), uint64(1234567890))) assert.Equal(t, true, equality.DeepEqual(uint64(1234567890), uint64(1234567890)))
assert.Equal(t, false, ssz.DeepEqual(uint64(1234567890), uint64(987653210))) assert.Equal(t, false, equality.DeepEqual(uint64(1234567890), uint64(987653210)))
assert.Equal(t, true, ssz.DeepEqual("hello", "hello")) assert.Equal(t, true, equality.DeepEqual("hello", "hello"))
assert.Equal(t, false, ssz.DeepEqual("hello", "world")) assert.Equal(t, false, equality.DeepEqual("hello", "world"))
assert.Equal(t, true, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 3})) assert.Equal(t, true, equality.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 3}))
assert.Equal(t, false, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 4})) assert.Equal(t, false, equality.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 4}))
var nilSlice1, nilSlice2 []byte var nilSlice1, nilSlice2 []byte
assert.Equal(t, true, ssz.DeepEqual(nilSlice1, nilSlice2)) assert.Equal(t, true, equality.DeepEqual(nilSlice1, nilSlice2))
assert.Equal(t, true, ssz.DeepEqual(nilSlice1, []byte{})) assert.Equal(t, true, equality.DeepEqual(nilSlice1, []byte{}))
assert.Equal(t, true, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 3})) assert.Equal(t, true, equality.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 3}))
assert.Equal(t, false, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 4})) assert.Equal(t, false, equality.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 4}))
} }
func TestDeepEqualStructs(t *testing.T) { func TestDeepEqualStructs(t *testing.T) {
@@ -39,8 +39,8 @@ func TestDeepEqualStructs(t *testing.T) {
store1 := Store{uint64(1234), nil} store1 := Store{uint64(1234), nil}
store2 := Store{uint64(1234), []byte{}} store2 := Store{uint64(1234), []byte{}}
store3 := Store{uint64(4321), []byte{}} store3 := Store{uint64(4321), []byte{}}
assert.Equal(t, true, ssz.DeepEqual(store1, store2)) assert.Equal(t, true, equality.DeepEqual(store1, store2))
assert.Equal(t, false, ssz.DeepEqual(store1, store3)) assert.Equal(t, false, equality.DeepEqual(store1, store3))
} }
func TestDeepEqualStructs_Unexported(t *testing.T) { func TestDeepEqualStructs_Unexported(t *testing.T) {
@@ -53,14 +53,14 @@ func TestDeepEqualStructs_Unexported(t *testing.T) {
store2 := Store{uint64(1234), []byte{}, "hi there"} store2 := Store{uint64(1234), []byte{}, "hi there"}
store3 := Store{uint64(4321), []byte{}, "wow"} store3 := Store{uint64(4321), []byte{}, "wow"}
store4 := Store{uint64(4321), []byte{}, "bow wow"} store4 := Store{uint64(4321), []byte{}, "bow wow"}
assert.Equal(t, true, ssz.DeepEqual(store1, store2)) assert.Equal(t, true, equality.DeepEqual(store1, store2))
assert.Equal(t, false, ssz.DeepEqual(store1, store3)) assert.Equal(t, false, equality.DeepEqual(store1, store3))
assert.Equal(t, false, ssz.DeepEqual(store3, store4)) assert.Equal(t, false, equality.DeepEqual(store3, store4))
} }
func TestDeepEqualProto(t *testing.T) { func TestDeepEqualProto(t *testing.T) {
var fork1, fork2 *ethpb.Fork var fork1, fork2 *ethpb.Fork
assert.Equal(t, true, ssz.DeepEqual(fork1, fork2)) assert.Equal(t, true, equality.DeepEqual(fork1, fork2))
fork1 = &ethpb.Fork{ fork1 = &ethpb.Fork{
PreviousVersion: []byte{123}, PreviousVersion: []byte{123},
@@ -72,8 +72,8 @@ func TestDeepEqualProto(t *testing.T) {
CurrentVersion: []byte{125}, CurrentVersion: []byte{125},
Epoch: 1234567890, Epoch: 1234567890,
} }
assert.Equal(t, true, ssz.DeepEqual(fork1, fork1)) assert.Equal(t, true, equality.DeepEqual(fork1, fork1))
assert.Equal(t, false, ssz.DeepEqual(fork1, fork2)) assert.Equal(t, false, equality.DeepEqual(fork1, fork2))
checkpoint1 := &ethpb.Checkpoint{ checkpoint1 := &ethpb.Checkpoint{
Epoch: 1234567890, Epoch: 1234567890,
@@ -83,7 +83,7 @@ func TestDeepEqualProto(t *testing.T) {
Epoch: 1234567890, Epoch: 1234567890,
Root: nil, Root: nil,
} }
assert.Equal(t, true, ssz.DeepEqual(checkpoint1, checkpoint2)) assert.Equal(t, true, equality.DeepEqual(checkpoint1, checkpoint2))
} }
func Test_IsProto(t *testing.T) { func Test_IsProto(t *testing.T) {
@@ -125,7 +125,7 @@ func Test_IsProto(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := ssz.IsProto(tt.item); got != tt.want { if got := equality.IsProto(tt.item); got != tt.want {
t.Errorf("isProtoSlice() = %v, want %v", got, tt.want) t.Errorf("isProtoSlice() = %v, want %v", got, tt.want)
} }
}) })

View File

@@ -8,6 +8,7 @@ import (
"github.com/minio/sha256-simd" "github.com/minio/sha256-simd"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/prysmaticlabs/go-bitfield" "github.com/prysmaticlabs/go-bitfield"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/encoding/bytesutil" "github.com/prysmaticlabs/prysm/encoding/bytesutil"
) )
@@ -50,6 +51,9 @@ func BitwiseMerkleize(hasher HashFn, chunks [][]byte, count, limit uint64) ([32]
if count > limit { if count > limit {
return [32]byte{}, errors.New("merkleizing list that is too large, over limit") return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
} }
if features.Get().EnableVectorizedHTR {
return MerkleizeList(chunks, limit), nil
}
hashFn := NewHasherFunc(hasher) hashFn := NewHasherFunc(hasher)
leafIndexer := func(i uint64) []byte { leafIndexer := func(i uint64) []byte {
return chunks[i] return chunks[i]
@@ -62,6 +66,9 @@ func BitwiseMerkleizeArrays(hasher HashFn, chunks [][32]byte, count, limit uint6
if count > limit { if count > limit {
return [32]byte{}, errors.New("merkleizing list that is too large, over limit") return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
} }
if features.Get().EnableVectorizedHTR {
return MerkleizeVector(chunks, limit), nil
}
hashFn := NewHasherFunc(hasher) hashFn := NewHasherFunc(hasher)
leafIndexer := func(i uint64) []byte { leafIndexer := func(i uint64) []byte {
return chunks[i][:] return chunks[i][:]

View File

@@ -2,6 +2,7 @@ package ssz
import ( import (
"github.com/prysmaticlabs/prysm/container/trie" "github.com/prysmaticlabs/prysm/container/trie"
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
) )
// Merkleize.go is mostly a directly copy of the same filename from // Merkleize.go is mostly a directly copy of the same filename from
@@ -196,3 +197,40 @@ func ConstructProof(hasher Hasher, count, limit uint64, leaf func(i uint64) []by
return return
} }
// MerkleizeVector uses our optimized routine to hash a list of 32-byte
// elements.
func MerkleizeVector(elements [][32]byte, length uint64) [32]byte {
depth := Depth(length)
// Return zerohash at depth
if len(elements) == 0 {
return trie.ZeroHashes[depth]
}
for i := 0; i < int(depth); i++ {
layerLen := len(elements)
oddNodeLength := layerLen%2 == 1
if oddNodeLength {
zerohash := trie.ZeroHashes[i]
elements = append(elements, zerohash)
}
outputLen := len(elements) / 2
htr.VectorizedSha256(elements, elements)
elements = elements[:outputLen]
}
return elements[0]
}
// MerkleizeList uses our optimized routine to hash a 2d-list of
// elements.
func MerkleizeList(elements [][]byte, length uint64) [32]byte {
depth := Depth(length)
// Return zerohash at depth
if len(elements) == 0 {
return trie.ZeroHashes[depth]
}
newElems := make([][32]byte, len(elements))
for i := range elements {
copy(newElems[i][:], elements[i])
}
return MerkleizeVector(newElems, length)
}

1
go.mod
View File

@@ -253,6 +253,7 @@ require (
github.com/holiman/uint256 v1.2.0 github.com/holiman/uint256 v1.2.0
github.com/peterh/liner v1.2.0 // indirect github.com/peterh/liner v1.2.0 // indirect
github.com/prometheus/tsdb v0.10.0 // indirect github.com/prometheus/tsdb v0.10.0 // indirect
github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 // indirect golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 // indirect
google.golang.org/api v0.34.0 // indirect google.golang.org/api v0.34.0 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect

2
go.sum
View File

@@ -1123,6 +1123,8 @@ github.com/prysmaticlabs/fastssz v0.0.0-20220110145812-fafb696cae88/go.mod h1:AS
github.com/prysmaticlabs/go-bitfield v0.0.0-20210108222456-8e92c3709aa0/go.mod h1:hCwmef+4qXWjv0jLDbQdWnL0Ol7cS7/lCSS26WR+u6s= github.com/prysmaticlabs/go-bitfield v0.0.0-20210108222456-8e92c3709aa0/go.mod h1:hCwmef+4qXWjv0jLDbQdWnL0Ol7cS7/lCSS26WR+u6s=
github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7 h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw= github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7 h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw=
github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7/go.mod h1:wmuf/mdK4VMD+jA9ThwcUKjg3a2XWM9cVfFYjDyY4j4= github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7/go.mod h1:wmuf/mdK4VMD+jA9ThwcUKjg3a2XWM9cVfFYjDyY4j4=
github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f h1:qVXMFNUGWbPpP7986YPsmL84A54oY7b/N+SOIBS3i5w=
github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f/go.mod h1:4pWaT30XoEx1j8KNJf3TV+E3mQkaufn7mf+jRNb/Fuk=
github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1 h1:xcu59yYL6AWWTl6jtejBfE0y8uF35fArCBeZjRlvJss= github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1 h1:xcu59yYL6AWWTl6jtejBfE0y8uF35fArCBeZjRlvJss=
github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1/go.mod h1:IOyTYjcIO0rkmnGBfJTL0NJ11exy/Tc2QEuv7hCXp24= github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1/go.mod h1:IOyTYjcIO0rkmnGBfJTL0NJ11exy/Tc2QEuv7hCXp24=
github.com/prysmaticlabs/prombbolt v0.0.0-20210126082820-9b7adba6db7c h1:9PHRCuO/VN0s9k+RmLykho7AjDxblNYI5bYKed16NPU= github.com/prysmaticlabs/prombbolt v0.0.0-20210126082820-9b7adba6db7c h1:9PHRCuO/VN0s9k+RmLykho7AjDxblNYI5bYKed16NPU=

View File

@@ -28,7 +28,7 @@ go_test(
deps = [ deps = [
"//config/params:go_default_library", "//config/params:go_default_library",
"//crypto/bls:go_default_library", "//crypto/bls:go_default_library",
"//encoding/ssz:go_default_library", "//encoding/ssz/equality:go_default_library",
"//proto/prysm/v1alpha1:go_default_library", "//proto/prysm/v1alpha1:go_default_library",
"//proto/prysm/v1alpha1/attestation/aggregation:go_default_library", "//proto/prysm/v1alpha1/attestation/aggregation:go_default_library",
"//proto/prysm/v1alpha1/attestation/aggregation/testing:go_default_library", "//proto/prysm/v1alpha1/attestation/aggregation/testing:go_default_library",

View File

@@ -8,7 +8,7 @@ import (
"github.com/prysmaticlabs/go-bitfield" "github.com/prysmaticlabs/go-bitfield"
"github.com/prysmaticlabs/prysm/config/params" "github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/encoding/ssz" "github.com/prysmaticlabs/prysm/encoding/ssz/equality"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation" "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation"
aggtesting "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation/testing" aggtesting "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation/testing"
@@ -48,7 +48,7 @@ func TestAggregateAttestations_AggregatePair(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
got, err := AggregatePair(tt.a1, tt.a2) got, err := AggregatePair(tt.a1, tt.a2)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, true, ssz.DeepEqual(got, tt.want)) require.Equal(t, true, equality.DeepEqual(got, tt.want))
} }
} }

View File

@@ -6,7 +6,7 @@ go_library(
importpath = "github.com/prysmaticlabs/prysm/testing/assertions", importpath = "github.com/prysmaticlabs/prysm/testing/assertions",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//encoding/ssz:go_default_library", "//encoding/ssz/equality:go_default_library",
"@com_github_d4l3k_messagediff//:go_default_library", "@com_github_d4l3k_messagediff//:go_default_library",
"@com_github_sirupsen_logrus//hooks/test:go_default_library", "@com_github_sirupsen_logrus//hooks/test:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//proto:go_default_library",

View File

@@ -9,7 +9,7 @@ import (
"strings" "strings"
"github.com/d4l3k/messagediff" "github.com/d4l3k/messagediff"
"github.com/prysmaticlabs/prysm/encoding/ssz" "github.com/prysmaticlabs/prysm/encoding/ssz/equality"
"github.com/sirupsen/logrus/hooks/test" "github.com/sirupsen/logrus/hooks/test"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
@@ -61,7 +61,7 @@ func DeepNotEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg
// DeepSSZEqual compares values using ssz.DeepEqual. // DeepSSZEqual compares values using ssz.DeepEqual.
func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) { func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
if !ssz.DeepEqual(expected, actual) { if !equality.DeepEqual(expected, actual) {
errMsg := parseMsg("Values are not equal", msg...) errMsg := parseMsg("Values are not equal", msg...)
_, file, line, _ := runtime.Caller(2) _, file, line, _ := runtime.Caller(2)
diff, _ := messagediff.PrettyDiff(expected, actual) diff, _ := messagediff.PrettyDiff(expected, actual)
@@ -71,7 +71,7 @@ func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg
// DeepNotSSZEqual compares values using ssz.DeepEqual. // DeepNotSSZEqual compares values using ssz.DeepEqual.
func DeepNotSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) { func DeepNotSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
if ssz.DeepEqual(expected, actual) { if equality.DeepEqual(expected, actual) {
errMsg := parseMsg("Values are equal", msg...) errMsg := parseMsg("Values are equal", msg...)
_, file, line, _ := runtime.Caller(2) _, file, line, _ := runtime.Caller(2)
loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual) loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual)

View File

@@ -12,7 +12,7 @@ go_library(
deps = [ deps = [
"//beacon-chain/core/transition:go_default_library", "//beacon-chain/core/transition:go_default_library",
"//beacon-chain/state/v1:go_default_library", "//beacon-chain/state/v1:go_default_library",
"//encoding/ssz:go_default_library", "//encoding/ssz/equality:go_default_library",
"//proto/prysm/v1alpha1:go_default_library", "//proto/prysm/v1alpha1:go_default_library",
"//proto/prysm/v1alpha1/wrapper:go_default_library", "//proto/prysm/v1alpha1/wrapper:go_default_library",
"//runtime/version:go_default_library", "//runtime/version:go_default_library",

View File

@@ -13,7 +13,7 @@ import (
"github.com/kr/pretty" "github.com/kr/pretty"
"github.com/prysmaticlabs/prysm/beacon-chain/core/transition" "github.com/prysmaticlabs/prysm/beacon-chain/core/transition"
v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1" v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1"
"github.com/prysmaticlabs/prysm/encoding/ssz" "github.com/prysmaticlabs/prysm/encoding/ssz/equality"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper" "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper"
"github.com/prysmaticlabs/prysm/runtime/version" "github.com/prysmaticlabs/prysm/runtime/version"
@@ -191,7 +191,7 @@ func main() {
if err := dataFetcher(expectedPostStatePath, expectedState); err != nil { if err := dataFetcher(expectedPostStatePath, expectedState); err != nil {
log.Fatal(err) log.Fatal(err)
} }
if !ssz.DeepEqual(expectedState, postState.InnerStateUnsafe()) { if !equality.DeepEqual(expectedState, postState.InnerStateUnsafe()) {
diff, _ := messagediff.PrettyDiff(expectedState, postState.InnerStateUnsafe()) diff, _ := messagediff.PrettyDiff(expectedState, postState.InnerStateUnsafe())
log.Errorf("Derived state differs from provided post state: %s", diff) log.Errorf("Derived state differs from provided post state: %s", diff)
} }