Integration of Vectorized Sha256 In Prysm (#10166)

* add changes

* fix for vectorize

* fix bug

* add new bench

* use new algorithms

* add latest updates

* save progress

* hack even more

* add more changes

* change library

* go mod

* fix deps

* fix dumb bug

* add flag and remove redundant code

* clean up better

* remove those ones

* clean up benches

* clean up benches

* cleanup

* gaz

* revert change

* potuz's review

* potuz's review

* potuz's review

* gaz

* potuz's review

* remove cyclical import

* revert ide changes

* potuz's review

* return
This commit is contained in:
Nishant Das
2022-02-28 21:56:12 +08:00
committed by GitHub
parent 12ba8f3645
commit 339540274b
30 changed files with 491 additions and 95 deletions

View File

@@ -258,3 +258,21 @@ func TestBeaconState_AppendValidator_DoesntMutateCopy(t *testing.T) {
_, ok := st1.ValidatorIndexByPubkey(bytesutil.ToBytes48(val.PublicKey))
assert.Equal(t, false, ok, "Expected no validator index to be present in st1 for the newly inserted pubkey")
}
func BenchmarkBeaconState(b *testing.B) {
testState, _ := util.DeterministicGenesisState(b, 16000)
pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe())
require.NoError(b, err)
b.Run("Vectorized SHA256", func(b *testing.B) {
st, err := v1.InitializeFromProtoUnsafe(pbState)
require.NoError(b, err)
_, err = st.HashTreeRoot(context.Background())
assert.NoError(b, err)
})
b.Run("Current SHA256", func(b *testing.B) {
_, err := pbState.HashTreeRoot()
require.NoError(b, err)
})
}

View File

@@ -32,10 +32,12 @@ go_library(
],
deps = [
"//beacon-chain/core/transition/stateutils:go_default_library",
"//config/features:go_default_library",
"//config/fieldparams:go_default_library",
"//config/params:go_default_library",
"//container/trie:go_default_library",
"//crypto/hash:go_default_library",
"//crypto/hash/htr:go_default_library",
"//encoding/bytesutil:go_default_library",
"//encoding/ssz:go_default_library",
"//math:go_default_library",
@@ -51,6 +53,7 @@ go_test(
srcs = [
"benchmark_test.go",
"field_root_test.go",
"field_root_validator_test.go",
"reference_bench_test.go",
"state_root_test.go",
"trie_helpers_test.go",
@@ -58,11 +61,14 @@ go_test(
],
embed = [":go_default_library"],
deps = [
"//beacon-chain/state:go_default_library",
"//config/features:go_default_library",
"//config/fieldparams:go_default_library",
"//config/params:go_default_library",
"//crypto/hash:go_default_library",
"//encoding/bytesutil:go_default_library",
"//encoding/ssz:go_default_library",
"//math:go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
"//runtime/interop:go_default_library",
"//testing/assert:go_default_library",

View File

@@ -5,12 +5,25 @@ import (
"encoding/binary"
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/config/features"
fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams"
"github.com/prysmaticlabs/prysm/crypto/hash"
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
"github.com/prysmaticlabs/prysm/encoding/ssz"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
)
const (
// number of field roots for the validator object.
validatorFieldRoots = 8
// Depth of tree representation of an individual
// validator.
// NumOfRoots = 2 ^ (TreeDepth)
// 8 = 2 ^ 3
validatorTreeDepth = 3
)
// ValidatorRegistryRoot computes the HashTreeRoot Merkleization of
// a list of validator structs according to the Ethereum
// Simple Serialize specification.
@@ -19,14 +32,20 @@ func ValidatorRegistryRoot(vals []*ethpb.Validator) ([32]byte, error) {
}
func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) {
roots := make([][32]byte, len(validators))
hasher := hash.CustomSHA256Hasher()
for i := 0; i < len(validators); i++ {
val, err := validatorRoot(hasher, validators[i])
var err error
var roots [][32]byte
if features.Get().EnableVectorizedHTR {
roots, err = optimizedValidatorRoots(validators)
if err != nil {
return [32]byte{}, errors.Wrap(err, "could not compute validators merkleization")
return [32]byte{}, err
}
} else {
roots, err = validatorRoots(hasher, validators)
if err != nil {
return [32]byte{}, err
}
roots[i] = val
}
validatorsRootsRoot, err := ssz.BitwiseMerkleizeArrays(hasher, roots, uint64(len(roots)), fieldparams.ValidatorRegistryLimit)
@@ -45,6 +64,42 @@ func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) {
return res, nil
}
func validatorRoots(hasher func([]byte) [32]byte, validators []*ethpb.Validator) ([][32]byte, error) {
roots := make([][32]byte, len(validators))
for i := 0; i < len(validators); i++ {
val, err := validatorRoot(hasher, validators[i])
if err != nil {
return [][32]byte{}, errors.Wrap(err, "could not compute validators merkleization")
}
roots[i] = val
}
return roots, nil
}
func optimizedValidatorRoots(validators []*ethpb.Validator) ([][32]byte, error) {
roots := make([][32]byte, 0, len(validators)*validatorFieldRoots)
hasher := hash.CustomSHA256Hasher()
for i := 0; i < len(validators); i++ {
fRoots, err := ValidatorFieldRoots(hasher, validators[i])
if err != nil {
return [][32]byte{}, errors.Wrap(err, "could not compute validators merkleization")
}
roots = append(roots, fRoots...)
}
// A validator's tree can represented with a depth of 3. As log2(8) = 3
// Using this property we can lay out all the individual fields of a
// validator and hash them in single level using our vectorized routine.
for i := 0; i < validatorTreeDepth; i++ {
// Overwrite input lists as we are hashing by level
// and only need the highest level to proceed.
outputLen := len(roots) / 2
htr.VectorizedSha256(roots, roots)
roots = roots[:outputLen]
}
return roots, nil
}
func validatorRoot(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) {
if validator == nil {
return [32]byte{}, errors.New("nil validator")

View File

@@ -0,0 +1,32 @@
package stateutil
import (
"reflect"
"strings"
"testing"
mathutil "github.com/prysmaticlabs/prysm/math"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/testing/assert"
)
func TestValidatorConstants(t *testing.T) {
v := &ethpb.Validator{}
refV := reflect.ValueOf(v).Elem()
numFields := refV.NumField()
numOfValFields := 0
for i := 0; i < numFields; i++ {
if strings.Contains(refV.Type().Field(i).Name, "state") ||
strings.Contains(refV.Type().Field(i).Name, "sizeCache") ||
strings.Contains(refV.Type().Field(i).Name, "unknownFields") {
continue
}
numOfValFields++
}
assert.Equal(t, validatorFieldRoots, numOfValFields)
assert.Equal(t, uint64(validatorFieldRoots), mathutil.PowerOf2(validatorTreeDepth))
_, err := ValidatorRegistryRoot([]*ethpb.Validator{v})
assert.NoError(t, err)
}

View File

@@ -5,8 +5,10 @@ import (
"encoding/binary"
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/container/trie"
"github.com/prysmaticlabs/prysm/crypto/hash"
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
"github.com/prysmaticlabs/prysm/encoding/ssz"
"github.com/prysmaticlabs/prysm/math"
)
@@ -61,25 +63,43 @@ func ReturnTrieLayerVariable(elements [][32]byte, length uint64) [][]*[32]byte {
layers[0] = transformedLeaves
buffer := bytes.NewBuffer([]byte{})
buffer.Grow(64)
for i := 0; i < int(depth); i++ {
oddNodeLength := len(layers[i])%2 == 1
if oddNodeLength {
zerohash := trie.ZeroHashes[i]
layers[i] = append(layers[i], &zerohash)
layerLen := len(layers[i])
oddNodeLength := layerLen%2 == 1
if features.Get().EnableVectorizedHTR {
if oddNodeLength {
zerohash := trie.ZeroHashes[i]
elements = append(elements, zerohash)
layerLen++
}
layers[i+1] = make([]*[32]byte, layerLen/2)
newElems := make([][32]byte, layerLen/2)
htr.VectorizedSha256(elements, newElems)
elements = newElems
for j := range elements {
layers[i+1][j] = &elements[j]
}
} else {
if oddNodeLength {
zerohash := trie.ZeroHashes[i]
layers[i] = append(layers[i], &zerohash)
}
updatedValues := make([]*[32]byte, 0, len(layers[i])/2)
for j := 0; j < len(layers[i]); j += 2 {
buffer.Write(layers[i][j][:])
buffer.Write(layers[i][j+1][:])
concat := hasher(buffer.Bytes())
updatedValues = append(updatedValues, &concat)
buffer.Reset()
}
// remove zerohash node from tree
if oddNodeLength {
layers[i] = layers[i][:len(layers[i])-1]
}
layers[i+1] = updatedValues
}
updatedValues := make([]*[32]byte, 0, len(layers[i])/2)
for j := 0; j < len(layers[i]); j += 2 {
buffer.Write(layers[i][j][:])
buffer.Write(layers[i][j+1][:])
concat := hasher(buffer.Bytes())
updatedValues = append(updatedValues, &concat)
buffer.Reset()
}
// remove zerohash node from tree
if oddNodeLength {
layers[i] = layers[i][:len(layers[i])-1]
}
layers[i+1] = updatedValues
}
return layers
}
@@ -277,18 +297,24 @@ func MerkleizeTrieLeaves(layers [][][32]byte, hashLayer [][32]byte,
chunkBuffer := bytes.NewBuffer([]byte{})
chunkBuffer.Grow(64)
for len(hashLayer) > 1 && i < len(layers) {
layer := make([][32]byte, len(hashLayer)/2)
if !math.IsPowerOf2(uint64(len(hashLayer))) {
return nil, nil, errors.Errorf("hash layer is a non power of 2: %d", len(hashLayer))
}
for j := 0; j < len(hashLayer); j += 2 {
chunkBuffer.Write(hashLayer[j][:])
chunkBuffer.Write(hashLayer[j+1][:])
hashedChunk := hasher(chunkBuffer.Bytes())
layer[j/2] = hashedChunk
chunkBuffer.Reset()
if features.Get().EnableVectorizedHTR {
newLayer := make([][32]byte, len(hashLayer)/2)
htr.VectorizedSha256(hashLayer, newLayer)
hashLayer = newLayer
} else {
layer := make([][32]byte, len(hashLayer)/2)
for j := 0; j < len(hashLayer); j += 2 {
chunkBuffer.Write(hashLayer[j][:])
chunkBuffer.Write(hashLayer[j+1][:])
hashedChunk := hasher(chunkBuffer.Bytes())
layer[j/2] = hashedChunk
chunkBuffer.Reset()
}
hashLayer = layer
}
hashLayer = layer
layers[i] = hashLayer
i++
}

View File

@@ -4,7 +4,9 @@ import (
"testing"
types "github.com/prysmaticlabs/eth2-types"
"github.com/prysmaticlabs/prysm/beacon-chain/state"
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/crypto/hash"
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
@@ -18,15 +20,56 @@ func TestReturnTrieLayer_OK(t *testing.T) {
newState, _ := util.DeterministicGenesisState(t, 32)
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
require.NoError(t, err)
blockRts := newState.BlockRoots()
roots := make([][32]byte, 0, len(blockRts))
for _, rt := range blockRts {
roots = append(roots, bytesutil.ToBytes32(rt))
}
roots := retrieveBlockRoots(newState)
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
assert.NoError(t, err)
newRoot := *layers[len(layers)-1][0]
assert.Equal(t, root, newRoot)
flags := &features.Flags{}
flags.EnableVectorizedHTR = true
reset := features.InitWithReset(flags)
defer reset()
layers, err = stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
assert.NoError(t, err)
lastRoot := *layers[len(layers)-1][0]
assert.Equal(t, root, lastRoot)
}
func BenchmarkReturnTrieLayer_NormalAlgorithm(b *testing.B) {
newState, _ := util.DeterministicGenesisState(b, 32)
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
require.NoError(b, err)
roots := retrieveBlockRoots(newState)
b.ResetTimer()
for i := 0; i < b.N; i++ {
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
assert.NoError(b, err)
newRoot := *layers[len(layers)-1][0]
assert.Equal(b, root, newRoot)
}
}
func BenchmarkReturnTrieLayer_VectorizedAlgorithm(b *testing.B) {
flags := &features.Flags{}
flags.EnableVectorizedHTR = true
reset := features.InitWithReset(flags)
defer reset()
newState, _ := util.DeterministicGenesisState(b, 32)
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
require.NoError(b, err)
roots := retrieveBlockRoots(newState)
b.ResetTimer()
for i := 0; i < b.N; i++ {
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
assert.NoError(b, err)
newRoot := *layers[len(layers)-1][0]
assert.Equal(b, root, newRoot)
}
}
func TestReturnTrieLayerVariable_OK(t *testing.T) {
@@ -46,15 +89,73 @@ func TestReturnTrieLayerVariable_OK(t *testing.T) {
newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators)))
require.NoError(t, err)
assert.Equal(t, root, newRoot)
flags := &features.Flags{}
flags.EnableVectorizedHTR = true
reset := features.InitWithReset(flags)
defer reset()
layers = stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit)
lastRoot := *layers[len(layers)-1][0]
lastRoot, err = stateutil.AddInMixin(lastRoot, uint64(len(validators)))
require.NoError(t, err)
assert.Equal(t, root, lastRoot)
}
func BenchmarkReturnTrieLayerVariable_NormalAlgorithm(b *testing.B) {
newState, _ := util.DeterministicGenesisState(b, 16000)
root, err := stateutil.ValidatorRegistryRoot(newState.Validators())
require.NoError(b, err)
hasher := hash.CustomSHA256Hasher()
validators := newState.Validators()
roots := make([][32]byte, 0, len(validators))
for _, val := range validators {
rt, err := stateutil.ValidatorRootWithHasher(hasher, val)
require.NoError(b, err)
roots = append(roots, rt)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
layers := stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit)
newRoot := *layers[len(layers)-1][0]
newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators)))
require.NoError(b, err)
assert.Equal(b, root, newRoot)
}
}
func BenchmarkReturnTrieLayerVariable_VectorizedAlgorithm(b *testing.B) {
flags := &features.Flags{}
flags.EnableVectorizedHTR = true
reset := features.InitWithReset(flags)
defer reset()
newState, _ := util.DeterministicGenesisState(b, 16000)
root, err := stateutil.ValidatorRegistryRoot(newState.Validators())
require.NoError(b, err)
hasher := hash.CustomSHA256Hasher()
validators := newState.Validators()
roots := make([][32]byte, 0, len(validators))
for _, val := range validators {
rt, err := stateutil.ValidatorRootWithHasher(hasher, val)
require.NoError(b, err)
roots = append(roots, rt)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
layers := stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit)
newRoot := *layers[len(layers)-1][0]
newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators)))
require.NoError(b, err)
assert.Equal(b, root, newRoot)
}
}
func TestRecomputeFromLayer_FixedSizedArray(t *testing.T) {
newState, _ := util.DeterministicGenesisState(t, 32)
blockRts := newState.BlockRoots()
roots := make([][32]byte, 0, len(blockRts))
for _, rt := range blockRts {
roots = append(roots, bytesutil.ToBytes32(rt))
}
roots := retrieveBlockRoots(newState)
layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots)))
require.NoError(t, err)
@@ -120,3 +221,12 @@ func TestMerkleizeTrieLeaves_BadHashLayer(t *testing.T) {
})
assert.ErrorContains(t, "hash layer is a non power of 2", err)
}
func retrieveBlockRoots(b state.BeaconState) [][32]byte {
blockRts := b.BlockRoots()
roots := make([][32]byte, 0, len(blockRts))
for _, rt := range blockRts {
roots = append(roots, bytesutil.ToBytes32(rt))
}
return roots
}

View File

@@ -4,8 +4,10 @@ import (
"encoding/binary"
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/config/features"
fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams"
"github.com/prysmaticlabs/prysm/crypto/hash"
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
"github.com/prysmaticlabs/prysm/encoding/ssz"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
@@ -14,6 +16,16 @@ import (
// ValidatorRootWithHasher describes a method from which the hash tree root
// of a validator is returned.
func ValidatorRootWithHasher(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) {
fieldRoots, err := ValidatorFieldRoots(hasher, validator)
if err != nil {
return [32]byte{}, err
}
return ssz.BitwiseMerkleizeArrays(hasher, fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots)))
}
// ValidatorFieldRoots describes a method from which the hash tree root
// of a validator is returned.
func ValidatorFieldRoots(hasher ssz.HashFn, validator *ethpb.Validator) ([][32]byte, error) {
var fieldRoots [][32]byte
if validator != nil {
pubkey := bytesutil.ToBytes48(validator.PublicKey)
@@ -40,18 +52,25 @@ func ValidatorRootWithHasher(hasher ssz.HashFn, validator *ethpb.Validator) ([32
binary.LittleEndian.PutUint64(withdrawalBuf[:8], uint64(validator.WithdrawableEpoch))
// Public key.
pubKeyChunks, err := ssz.Pack([][]byte{pubkey[:]})
pubKeyChunks, err := ssz.PackByChunk([][]byte{pubkey[:]})
if err != nil {
return [32]byte{}, err
return [][32]byte{}, err
}
pubKeyRoot, err := ssz.BitwiseMerkleize(hasher, pubKeyChunks, uint64(len(pubKeyChunks)), uint64(len(pubKeyChunks)))
if err != nil {
return [32]byte{}, err
var pubKeyRoot [32]byte
if features.Get().EnableVectorizedHTR {
outputChunk := make([][32]byte, 1)
htr.VectorizedSha256(pubKeyChunks, outputChunk)
pubKeyRoot = outputChunk[0]
} else {
pubKeyRoot, err = ssz.BitwiseMerkleizeArrays(hasher, pubKeyChunks, uint64(len(pubKeyChunks)), uint64(len(pubKeyChunks)))
if err != nil {
return [][32]byte{}, err
}
}
fieldRoots = [][32]byte{pubKeyRoot, withdrawCreds, effectiveBalanceBuf, slashBuf, activationEligibilityBuf,
activationBuf, exitBuf, withdrawalBuf}
}
return ssz.BitwiseMerkleizeArrays(hasher, fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots)))
return fieldRoots, nil
}
// Uint64ListRootWithRegistryLimit computes the HashTreeRoot Merkleization of

View File

@@ -177,6 +177,24 @@ func TestBeaconState_HashTreeRoot(t *testing.T) {
}
}
func BenchmarkBeaconState(b *testing.B) {
testState, _ := util.DeterministicGenesisState(b, 16000)
pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe())
require.NoError(b, err)
b.Run("Vectorized SHA256", func(b *testing.B) {
st, err := v1.InitializeFromProtoUnsafe(pbState)
require.NoError(b, err)
_, err = st.HashTreeRoot(context.Background())
assert.NoError(b, err)
})
b.Run("Current SHA256", func(b *testing.B) {
_, err := pbState.HashTreeRoot()
require.NoError(b, err)
})
}
func TestBeaconState_HashTreeRoot_FieldTrie(t *testing.T) {
testState, _ := util.DeterministicGenesisState(t, 64)

View File

@@ -86,7 +86,7 @@ go_library(
"//crypto/bls:go_default_library",
"//crypto/rand:go_default_library",
"//encoding/bytesutil:go_default_library",
"//encoding/ssz:go_default_library",
"//encoding/ssz/equality:go_default_library",
"//monitoring/tracing:go_default_library",
"//network/forks:go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
@@ -195,7 +195,7 @@ go_test(
"//crypto/bls:go_default_library",
"//crypto/rand:go_default_library",
"//encoding/bytesutil:go_default_library",
"//encoding/ssz:go_default_library",
"//encoding/ssz/equality:go_default_library",
"//network/forks:go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
"//proto/prysm/v1alpha1/attestation:go_default_library",

View File

@@ -15,7 +15,7 @@ import (
"github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/crypto/rand"
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
"github.com/prysmaticlabs/prysm/encoding/ssz"
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
"github.com/prysmaticlabs/prysm/monitoring/tracing"
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/block"
"github.com/prysmaticlabs/prysm/time/slots"
@@ -328,7 +328,7 @@ func (s *Service) deleteBlockFromPendingQueue(slot types.Slot, b block.SignedBea
newBlks := make([]block.SignedBeaconBlock, 0, len(blks))
for _, blk := range blks {
if ssz.DeepEqual(blk.Proto(), b.Proto()) {
if equality.DeepEqual(blk.Proto(), b.Proto()) {
continue
}
newBlks = append(newBlks, blk)

View File

@@ -17,7 +17,7 @@ import (
"github.com/prysmaticlabs/prysm/beacon-chain/p2p"
p2ptest "github.com/prysmaticlabs/prysm/beacon-chain/p2p/testing"
"github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/encoding/ssz"
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
pb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/metadata"
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper"
@@ -125,7 +125,7 @@ func TestMetadataRPCHandler_SendsMetadata(t *testing.T) {
metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID())
assert.NoError(t, err)
if !ssz.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
if !equality.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
t.Fatalf("MetadataV0 unequal, received %v but wanted %v", metadata, p2.LocalMetadata)
}
@@ -213,7 +213,7 @@ func TestMetadataRPCHandler_SendsMetadataAltair(t *testing.T) {
metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID())
assert.NoError(t, err)
if !ssz.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
if !equality.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) {
t.Fatalf("MetadataV1 unequal, received %v but wanted %v", metadata, p2.LocalMetadata)
}

View File

@@ -74,7 +74,8 @@ type Flags struct {
CorrectlyInsertOrphanedAtts bool
CorrectlyPruneCanonicalAtts bool
EnableNativeState bool // EnableNativeState defines whether the beacon state will be represented as a pure Go struct or a Go struct that wraps a proto struct.
EnableNativeState bool // EnableNativeState defines whether the beacon state will be represented as a pure Go struct or a Go struct that wraps a proto struct.
EnableVectorizedHTR bool // EnableVectorizedHTR specifies whether the beacon state will use the optimized sha256 routines.
// KeystoreImportDebounceInterval specifies the time duration the validator waits to reload new keys if they have
// changed on disk. This feature is for advanced use cases only.
@@ -222,6 +223,10 @@ func ConfigureBeaconChain(ctx *cli.Context) {
logEnabled(enableNativeState)
cfg.EnableNativeState = true
}
if ctx.Bool(enableVecHTR.Name) {
logEnabled(enableVecHTR)
cfg.EnableVectorizedHTR = true
}
Init(cfg)
}

View File

@@ -134,11 +134,16 @@ var (
Name: "enable-native-state",
Usage: "Enables representing the beacon state as a pure Go struct.",
}
enableVecHTR = &cli.BoolFlag{
Name: "enable-vectorized-htr",
Usage: "Enables new go sha256 library which utilizes optimized routines for merkle trees",
}
)
// devModeFlags holds list of flags that are set when development mode is on.
var devModeFlags = []cli.Flag{
enablePeerScorer,
enableVecHTR,
}
// ValidatorFlags contains a list of all the feature flags that apply to the validator client.
@@ -183,6 +188,7 @@ var BeaconChainFlags = append(deprecatedFlags, []cli.Flag{
disableBatchGossipVerification,
disableBalanceTrieComputation,
enableNativeState,
enableVecHTR,
}...)
// E2EBeaconChainFlags contains a list of the beacon chain feature flags to be tested in E2E.

View File

@@ -0,0 +1,9 @@
load("@prysm//tools/go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = ["hashtree.go"],
importpath = "github.com/prysmaticlabs/prysm/crypto/hash/htr",
visibility = ["//visibility:public"],
deps = ["@com_github_prysmaticlabs_gohashtree//:go_default_library"],
)

View File

@@ -0,0 +1,17 @@
package htr
import (
"github.com/prysmaticlabs/gohashtree"
)
// VectorizedSha256 takes a list of roots and hashes them using CPU
// specific vector instructions. Depending on host machine's specific
// hardware configuration, using this routine can lead to a significant
// performance improvement compared to the default method of hashing
// lists.
func VectorizedSha256(inputList [][32]byte, outputList [][32]byte) {
err := gohashtree.Hash(outputList, inputList)
if err != nil {
panic(err)
}
}

View File

@@ -2982,6 +2982,13 @@ def prysm_deps():
sum = "h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw=",
version = "v0.0.0-20210809151128-385d8c5e3fb7",
)
go_repository(
name = "com_github_prysmaticlabs_gohashtree",
importpath = "github.com/prysmaticlabs/gohashtree",
sum = "h1:qVXMFNUGWbPpP7986YPsmL84A54oY7b/N+SOIBS3i5w=",
version = "v0.0.0-20220208111633-0606f58df32f",
)
go_repository(
name = "com_github_prysmaticlabs_prombbolt",
importpath = "github.com/prysmaticlabs/prombbolt",

View File

@@ -3,7 +3,6 @@ load("@prysm//tools/go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
srcs = [
"deep_equal.go",
"hashers.go",
"helpers.go",
"htrutils.go",
@@ -12,16 +11,16 @@ go_library(
importpath = "github.com/prysmaticlabs/prysm/encoding/ssz",
visibility = ["//visibility:public"],
deps = [
"//config/features:go_default_library",
"//config/fieldparams:go_default_library",
"//container/trie:go_default_library",
"//crypto/hash:go_default_library",
"//crypto/hash/htr:go_default_library",
"//encoding/bytesutil:go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
"@com_github_minio_sha256_simd//:go_default_library",
"@com_github_pkg_errors//:go_default_library",
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)
@@ -29,7 +28,6 @@ go_test(
name = "go_default_test",
size = "small",
srcs = [
"deep_equal_test.go",
"hashers_test.go",
"helpers_test.go",
"htrutils_test.go",

View File

@@ -0,0 +1,22 @@
load("@prysm//tools/go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
srcs = ["deep_equal.go"],
importpath = "github.com/prysmaticlabs/prysm/encoding/ssz/equality",
visibility = ["//visibility:public"],
deps = [
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = ["deep_equal_test.go"],
deps = [
":go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
"//testing/assert:go_default_library",
],
)

View File

@@ -1,4 +1,4 @@
package ssz
package equality
import (
"reflect"

View File

@@ -1,34 +1,34 @@
package ssz_test
package equality_test
import (
"testing"
"github.com/prysmaticlabs/prysm/encoding/ssz"
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/testing/assert"
)
func TestDeepEqualBasicTypes(t *testing.T) {
assert.Equal(t, true, ssz.DeepEqual(true, true))
assert.Equal(t, false, ssz.DeepEqual(true, false))
assert.Equal(t, true, equality.DeepEqual(true, true))
assert.Equal(t, false, equality.DeepEqual(true, false))
assert.Equal(t, true, ssz.DeepEqual(byte(222), byte(222)))
assert.Equal(t, false, ssz.DeepEqual(byte(222), byte(111)))
assert.Equal(t, true, equality.DeepEqual(byte(222), byte(222)))
assert.Equal(t, false, equality.DeepEqual(byte(222), byte(111)))
assert.Equal(t, true, ssz.DeepEqual(uint64(1234567890), uint64(1234567890)))
assert.Equal(t, false, ssz.DeepEqual(uint64(1234567890), uint64(987653210)))
assert.Equal(t, true, equality.DeepEqual(uint64(1234567890), uint64(1234567890)))
assert.Equal(t, false, equality.DeepEqual(uint64(1234567890), uint64(987653210)))
assert.Equal(t, true, ssz.DeepEqual("hello", "hello"))
assert.Equal(t, false, ssz.DeepEqual("hello", "world"))
assert.Equal(t, true, equality.DeepEqual("hello", "hello"))
assert.Equal(t, false, equality.DeepEqual("hello", "world"))
assert.Equal(t, true, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 3}))
assert.Equal(t, false, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 4}))
assert.Equal(t, true, equality.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 3}))
assert.Equal(t, false, equality.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 4}))
var nilSlice1, nilSlice2 []byte
assert.Equal(t, true, ssz.DeepEqual(nilSlice1, nilSlice2))
assert.Equal(t, true, ssz.DeepEqual(nilSlice1, []byte{}))
assert.Equal(t, true, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 3}))
assert.Equal(t, false, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 4}))
assert.Equal(t, true, equality.DeepEqual(nilSlice1, nilSlice2))
assert.Equal(t, true, equality.DeepEqual(nilSlice1, []byte{}))
assert.Equal(t, true, equality.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 3}))
assert.Equal(t, false, equality.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 4}))
}
func TestDeepEqualStructs(t *testing.T) {
@@ -39,8 +39,8 @@ func TestDeepEqualStructs(t *testing.T) {
store1 := Store{uint64(1234), nil}
store2 := Store{uint64(1234), []byte{}}
store3 := Store{uint64(4321), []byte{}}
assert.Equal(t, true, ssz.DeepEqual(store1, store2))
assert.Equal(t, false, ssz.DeepEqual(store1, store3))
assert.Equal(t, true, equality.DeepEqual(store1, store2))
assert.Equal(t, false, equality.DeepEqual(store1, store3))
}
func TestDeepEqualStructs_Unexported(t *testing.T) {
@@ -53,14 +53,14 @@ func TestDeepEqualStructs_Unexported(t *testing.T) {
store2 := Store{uint64(1234), []byte{}, "hi there"}
store3 := Store{uint64(4321), []byte{}, "wow"}
store4 := Store{uint64(4321), []byte{}, "bow wow"}
assert.Equal(t, true, ssz.DeepEqual(store1, store2))
assert.Equal(t, false, ssz.DeepEqual(store1, store3))
assert.Equal(t, false, ssz.DeepEqual(store3, store4))
assert.Equal(t, true, equality.DeepEqual(store1, store2))
assert.Equal(t, false, equality.DeepEqual(store1, store3))
assert.Equal(t, false, equality.DeepEqual(store3, store4))
}
func TestDeepEqualProto(t *testing.T) {
var fork1, fork2 *ethpb.Fork
assert.Equal(t, true, ssz.DeepEqual(fork1, fork2))
assert.Equal(t, true, equality.DeepEqual(fork1, fork2))
fork1 = &ethpb.Fork{
PreviousVersion: []byte{123},
@@ -72,8 +72,8 @@ func TestDeepEqualProto(t *testing.T) {
CurrentVersion: []byte{125},
Epoch: 1234567890,
}
assert.Equal(t, true, ssz.DeepEqual(fork1, fork1))
assert.Equal(t, false, ssz.DeepEqual(fork1, fork2))
assert.Equal(t, true, equality.DeepEqual(fork1, fork1))
assert.Equal(t, false, equality.DeepEqual(fork1, fork2))
checkpoint1 := &ethpb.Checkpoint{
Epoch: 1234567890,
@@ -83,7 +83,7 @@ func TestDeepEqualProto(t *testing.T) {
Epoch: 1234567890,
Root: nil,
}
assert.Equal(t, true, ssz.DeepEqual(checkpoint1, checkpoint2))
assert.Equal(t, true, equality.DeepEqual(checkpoint1, checkpoint2))
}
func Test_IsProto(t *testing.T) {
@@ -125,7 +125,7 @@ func Test_IsProto(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ssz.IsProto(tt.item); got != tt.want {
if got := equality.IsProto(tt.item); got != tt.want {
t.Errorf("isProtoSlice() = %v, want %v", got, tt.want)
}
})

View File

@@ -8,6 +8,7 @@ import (
"github.com/minio/sha256-simd"
"github.com/pkg/errors"
"github.com/prysmaticlabs/go-bitfield"
"github.com/prysmaticlabs/prysm/config/features"
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
)
@@ -50,6 +51,9 @@ func BitwiseMerkleize(hasher HashFn, chunks [][]byte, count, limit uint64) ([32]
if count > limit {
return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
}
if features.Get().EnableVectorizedHTR {
return MerkleizeList(chunks, limit), nil
}
hashFn := NewHasherFunc(hasher)
leafIndexer := func(i uint64) []byte {
return chunks[i]
@@ -62,6 +66,9 @@ func BitwiseMerkleizeArrays(hasher HashFn, chunks [][32]byte, count, limit uint6
if count > limit {
return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
}
if features.Get().EnableVectorizedHTR {
return MerkleizeVector(chunks, limit), nil
}
hashFn := NewHasherFunc(hasher)
leafIndexer := func(i uint64) []byte {
return chunks[i][:]

View File

@@ -2,6 +2,7 @@ package ssz
import (
"github.com/prysmaticlabs/prysm/container/trie"
"github.com/prysmaticlabs/prysm/crypto/hash/htr"
)
// Merkleize.go is mostly a directly copy of the same filename from
@@ -196,3 +197,40 @@ func ConstructProof(hasher Hasher, count, limit uint64, leaf func(i uint64) []by
return
}
// MerkleizeVector uses our optimized routine to hash a list of 32-byte
// elements.
func MerkleizeVector(elements [][32]byte, length uint64) [32]byte {
depth := Depth(length)
// Return zerohash at depth
if len(elements) == 0 {
return trie.ZeroHashes[depth]
}
for i := 0; i < int(depth); i++ {
layerLen := len(elements)
oddNodeLength := layerLen%2 == 1
if oddNodeLength {
zerohash := trie.ZeroHashes[i]
elements = append(elements, zerohash)
}
outputLen := len(elements) / 2
htr.VectorizedSha256(elements, elements)
elements = elements[:outputLen]
}
return elements[0]
}
// MerkleizeList uses our optimized routine to hash a 2d-list of
// elements.
func MerkleizeList(elements [][]byte, length uint64) [32]byte {
depth := Depth(length)
// Return zerohash at depth
if len(elements) == 0 {
return trie.ZeroHashes[depth]
}
newElems := make([][32]byte, len(elements))
for i := range elements {
copy(newElems[i][:], elements[i])
}
return MerkleizeVector(newElems, length)
}

1
go.mod
View File

@@ -253,6 +253,7 @@ require (
github.com/holiman/uint256 v1.2.0
github.com/peterh/liner v1.2.0 // indirect
github.com/prometheus/tsdb v0.10.0 // indirect
github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 // indirect
google.golang.org/api v0.34.0 // indirect
google.golang.org/appengine v1.6.7 // indirect

2
go.sum
View File

@@ -1123,6 +1123,8 @@ github.com/prysmaticlabs/fastssz v0.0.0-20220110145812-fafb696cae88/go.mod h1:AS
github.com/prysmaticlabs/go-bitfield v0.0.0-20210108222456-8e92c3709aa0/go.mod h1:hCwmef+4qXWjv0jLDbQdWnL0Ol7cS7/lCSS26WR+u6s=
github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7 h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw=
github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7/go.mod h1:wmuf/mdK4VMD+jA9ThwcUKjg3a2XWM9cVfFYjDyY4j4=
github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f h1:qVXMFNUGWbPpP7986YPsmL84A54oY7b/N+SOIBS3i5w=
github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f/go.mod h1:4pWaT30XoEx1j8KNJf3TV+E3mQkaufn7mf+jRNb/Fuk=
github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1 h1:xcu59yYL6AWWTl6jtejBfE0y8uF35fArCBeZjRlvJss=
github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1/go.mod h1:IOyTYjcIO0rkmnGBfJTL0NJ11exy/Tc2QEuv7hCXp24=
github.com/prysmaticlabs/prombbolt v0.0.0-20210126082820-9b7adba6db7c h1:9PHRCuO/VN0s9k+RmLykho7AjDxblNYI5bYKed16NPU=

View File

@@ -28,7 +28,7 @@ go_test(
deps = [
"//config/params:go_default_library",
"//crypto/bls:go_default_library",
"//encoding/ssz:go_default_library",
"//encoding/ssz/equality:go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
"//proto/prysm/v1alpha1/attestation/aggregation:go_default_library",
"//proto/prysm/v1alpha1/attestation/aggregation/testing:go_default_library",

View File

@@ -8,7 +8,7 @@ import (
"github.com/prysmaticlabs/go-bitfield"
"github.com/prysmaticlabs/prysm/config/params"
"github.com/prysmaticlabs/prysm/encoding/ssz"
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation"
aggtesting "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation/testing"
@@ -48,7 +48,7 @@ func TestAggregateAttestations_AggregatePair(t *testing.T) {
for _, tt := range tests {
got, err := AggregatePair(tt.a1, tt.a2)
require.NoError(t, err)
require.Equal(t, true, ssz.DeepEqual(got, tt.want))
require.Equal(t, true, equality.DeepEqual(got, tt.want))
}
}

View File

@@ -6,7 +6,7 @@ go_library(
importpath = "github.com/prysmaticlabs/prysm/testing/assertions",
visibility = ["//visibility:public"],
deps = [
"//encoding/ssz:go_default_library",
"//encoding/ssz/equality:go_default_library",
"@com_github_d4l3k_messagediff//:go_default_library",
"@com_github_sirupsen_logrus//hooks/test:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",

View File

@@ -9,7 +9,7 @@ import (
"strings"
"github.com/d4l3k/messagediff"
"github.com/prysmaticlabs/prysm/encoding/ssz"
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
"github.com/sirupsen/logrus/hooks/test"
"google.golang.org/protobuf/proto"
)
@@ -61,7 +61,7 @@ func DeepNotEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg
// DeepSSZEqual compares values using ssz.DeepEqual.
func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
if !ssz.DeepEqual(expected, actual) {
if !equality.DeepEqual(expected, actual) {
errMsg := parseMsg("Values are not equal", msg...)
_, file, line, _ := runtime.Caller(2)
diff, _ := messagediff.PrettyDiff(expected, actual)
@@ -71,7 +71,7 @@ func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg
// DeepNotSSZEqual compares values using ssz.DeepEqual.
func DeepNotSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
if ssz.DeepEqual(expected, actual) {
if equality.DeepEqual(expected, actual) {
errMsg := parseMsg("Values are equal", msg...)
_, file, line, _ := runtime.Caller(2)
loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual)

View File

@@ -12,7 +12,7 @@ go_library(
deps = [
"//beacon-chain/core/transition:go_default_library",
"//beacon-chain/state/v1:go_default_library",
"//encoding/ssz:go_default_library",
"//encoding/ssz/equality:go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
"//proto/prysm/v1alpha1/wrapper:go_default_library",
"//runtime/version:go_default_library",

View File

@@ -13,7 +13,7 @@ import (
"github.com/kr/pretty"
"github.com/prysmaticlabs/prysm/beacon-chain/core/transition"
v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1"
"github.com/prysmaticlabs/prysm/encoding/ssz"
"github.com/prysmaticlabs/prysm/encoding/ssz/equality"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper"
"github.com/prysmaticlabs/prysm/runtime/version"
@@ -191,7 +191,7 @@ func main() {
if err := dataFetcher(expectedPostStatePath, expectedState); err != nil {
log.Fatal(err)
}
if !ssz.DeepEqual(expectedState, postState.InnerStateUnsafe()) {
if !equality.DeepEqual(expectedState, postState.InnerStateUnsafe()) {
diff, _ := messagediff.PrettyDiff(expectedState, postState.InnerStateUnsafe())
log.Errorf("Derived state differs from provided post state: %s", diff)
}