From f97db3b738438c81d75251e93650cdb16948d1b6 Mon Sep 17 00:00:00 2001 From: Potuz Date: Fri, 21 Jul 2023 20:36:20 -0400 Subject: [PATCH] Different parallel hashing (#12639) * Parallelize hashing of large lists * add unit test * add file * do not parallelize on low processor count * revert minimal proc count --------- Co-authored-by: Nishant Das --- beacon-chain/state/stateutil/BUILD.bazel | 1 + .../state/stateutil/field_root_validator.go | 38 ++++++++++++++++--- .../stateutil/field_root_validator_test.go | 27 +++++++++++++ .../state/stateutil/sync_committee.root.go | 3 +- beacon-chain/state/stateutil/trie_helpers.go | 8 +--- crypto/hash/htr/BUILD.bazel | 10 ++++- crypto/hash/htr/hashtree.go | 34 ++++++++++++++++- crypto/hash/htr/hashtree_test.go | 27 +++++++++++++ encoding/ssz/merkleize.go | 4 +- 9 files changed, 132 insertions(+), 20 deletions(-) create mode 100644 crypto/hash/htr/hashtree_test.go diff --git a/beacon-chain/state/stateutil/BUILD.bazel b/beacon-chain/state/stateutil/BUILD.bazel index f9847d6f98..653f3c1a30 100644 --- a/beacon-chain/state/stateutil/BUILD.bazel +++ b/beacon-chain/state/stateutil/BUILD.bazel @@ -34,6 +34,7 @@ go_library( "//math:go_default_library", "//proto/prysm/v1alpha1:go_default_library", "@com_github_pkg_errors//:go_default_library", + "@com_github_sirupsen_logrus//:go_default_library", ], ) diff --git a/beacon-chain/state/stateutil/field_root_validator.go b/beacon-chain/state/stateutil/field_root_validator.go index 4f88c24642..6f5d562fc2 100644 --- a/beacon-chain/state/stateutil/field_root_validator.go +++ b/beacon-chain/state/stateutil/field_root_validator.go @@ -3,12 +3,15 @@ package stateutil import ( "bytes" "encoding/binary" + "runtime" + "sync" "github.com/pkg/errors" fieldparams "github.com/prysmaticlabs/prysm/v4/config/fieldparams" "github.com/prysmaticlabs/prysm/v4/crypto/hash/htr" "github.com/prysmaticlabs/prysm/v4/encoding/ssz" ethpb "github.com/prysmaticlabs/prysm/v4/proto/prysm/v1alpha1" +
"github.com/sirupsen/logrus" ) const ( @@ -51,6 +54,20 @@ func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) { return res, nil } +func hashValidatorHelper(validators []*ethpb.Validator, roots [][32]byte, j int, groupSize int, wg *sync.WaitGroup) { + defer wg.Done() + for i := 0; i < groupSize; i++ { + fRoots, err := ValidatorFieldRoots(validators[j*groupSize+i]) + if err != nil { + logrus.WithError(err).Error("could not get validator field roots") + return + } + for k, root := range fRoots { + roots[(j*groupSize+i)*validatorFieldRoots+k] = root + } + } +} + // OptimizedValidatorRoots uses an optimized routine with gohashtree in order to // derive a list of validator roots from a list of validator objects. func OptimizedValidatorRoots(validators []*ethpb.Validator) ([][32]byte, error) { @@ -58,14 +75,25 @@ func OptimizedValidatorRoots(validators []*ethpb.Validator) ([][32]byte, error) if len(validators) == 0 { return [][32]byte{}, nil } - roots := make([][32]byte, 0, len(validators)*validatorFieldRoots) - for i := 0; i < len(validators); i++ { + wg := sync.WaitGroup{} + n := runtime.GOMAXPROCS(0) + rootsSize := len(validators) * validatorFieldRoots + groupSize := len(validators) / n + roots := make([][32]byte, rootsSize) + wg.Add(n - 1) + for j := 0; j < n-1; j++ { + go hashValidatorHelper(validators, roots, j, groupSize, &wg) + } + for i := (n - 1) * groupSize; i < len(validators); i++ { fRoots, err := ValidatorFieldRoots(validators[i]) if err != nil { return [][32]byte{}, errors.Wrap(err, "could not compute validators merkleization") } - roots = append(roots, fRoots...) + for k, root := range fRoots { + roots[i*validatorFieldRoots+k] = root + } } + wg.Wait() // A validator's tree can represented with a depth of 3. 
As log2(8) = 3 // Using this property we can lay out all the individual fields of a @@ -73,9 +101,7 @@ func OptimizedValidatorRoots(validators []*ethpb.Validator) ([][32]byte, error) for i := 0; i < validatorTreeDepth; i++ { // Overwrite input lists as we are hashing by level // and only need the highest level to proceed. - outputLen := len(roots) / 2 - htr.VectorizedSha256(roots, roots) - roots = roots[:outputLen] + roots = htr.VectorizedSha256(roots) } return roots, nil } diff --git a/beacon-chain/state/stateutil/field_root_validator_test.go b/beacon-chain/state/stateutil/field_root_validator_test.go index b261cd0370..ee3b29745c 100644 --- a/beacon-chain/state/stateutil/field_root_validator_test.go +++ b/beacon-chain/state/stateutil/field_root_validator_test.go @@ -3,11 +3,13 @@ package stateutil import ( "reflect" "strings" + "sync" "testing" mathutil "github.com/prysmaticlabs/prysm/v4/math" ethpb "github.com/prysmaticlabs/prysm/v4/proto/prysm/v1alpha1" "github.com/prysmaticlabs/prysm/v4/testing/assert" + "github.com/prysmaticlabs/prysm/v4/testing/require" ) func TestValidatorConstants(t *testing.T) { @@ -30,3 +32,28 @@ func TestValidatorConstants(t *testing.T) { _, err := ValidatorRegistryRoot([]*ethpb.Validator{v}) assert.NoError(t, err) } + +func TestHashValidatorHelper(t *testing.T) { + wg := sync.WaitGroup{} + wg.Add(1) + v := &ethpb.Validator{} + valList := make([]*ethpb.Validator, 10*validatorFieldRoots) + for i := range valList { + valList[i] = v + } + roots := make([][32]byte, len(valList)) + hashValidatorHelper(valList, roots, 2, 2, &wg) + for i := 0; i < 4*validatorFieldRoots; i++ { + require.Equal(t, [32]byte{}, roots[i]) + } + emptyValRoots, err := ValidatorFieldRoots(v) + require.NoError(t, err) + for i := 4; i < 6; i++ { + for j := 0; j < validatorFieldRoots; j++ { + require.Equal(t, emptyValRoots[j], roots[i*validatorFieldRoots+j]) + } + } + for i := 6 * validatorFieldRoots; i < 10*validatorFieldRoots; i++ { + require.Equal(t, [32]byte{}, roots[i])
+ } +} diff --git a/beacon-chain/state/stateutil/sync_committee.root.go b/beacon-chain/state/stateutil/sync_committee.root.go index 8384295137..9038325c10 100644 --- a/beacon-chain/state/stateutil/sync_committee.root.go +++ b/beacon-chain/state/stateutil/sync_committee.root.go @@ -48,8 +48,7 @@ func merkleizePubkey(pubkey []byte) ([32]byte, error) { if err != nil { return [32]byte{}, err } - outputChunk := make([][32]byte, 1) - htr.VectorizedSha256(chunks, outputChunk) + outputChunk := htr.VectorizedSha256(chunks) return outputChunk[0], nil } diff --git a/beacon-chain/state/stateutil/trie_helpers.go b/beacon-chain/state/stateutil/trie_helpers.go index a68bef4b7f..9c8815cc78 100644 --- a/beacon-chain/state/stateutil/trie_helpers.go +++ b/beacon-chain/state/stateutil/trie_helpers.go @@ -71,9 +71,7 @@ func ReturnTrieLayerVariable(elements [][32]byte, length uint64) [][]*[32]byte { } layers[i+1] = make([]*[32]byte, layerLen/2) - newElems := make([][32]byte, layerLen/2) - htr.VectorizedSha256(elements, newElems) - elements = newElems + elements = htr.VectorizedSha256(elements) for j := range elements { layers[i+1][j] = &elements[j] } @@ -295,9 +293,7 @@ func MerkleizeTrieLeaves(layers [][][32]byte, hashLayer [][32]byte) ([][][32]byt if !math.IsPowerOf2(uint64(len(hashLayer))) { return nil, nil, errors.Errorf("hash layer is a non power of 2: %d", len(hashLayer)) } - newLayer := make([][32]byte, len(hashLayer)/2) - htr.VectorizedSha256(hashLayer, newLayer) - hashLayer = newLayer + hashLayer = htr.VectorizedSha256(hashLayer) layers[i] = hashLayer i++ } diff --git a/crypto/hash/htr/BUILD.bazel b/crypto/hash/htr/BUILD.bazel index 4c20e2c140..74bc3e35be 100644 --- a/crypto/hash/htr/BUILD.bazel +++ b/crypto/hash/htr/BUILD.bazel @@ -1,4 +1,4 @@ -load("@prysm//tools/go:def.bzl", "go_library") +load("@prysm//tools/go:def.bzl", "go_library", "go_test") go_library( name = "go_default_library", @@ -7,3 +7,11 @@ go_library( visibility = ["//visibility:public"], deps = 
["@com_github_prysmaticlabs_gohashtree//:go_default_library"], ) + +go_test( + name = "go_default_test", + size = "small", + srcs = ["hashtree_test.go"], + embed = [":go_default_library"], + deps = ["//testing/require:go_default_library"], +) diff --git a/crypto/hash/htr/hashtree.go b/crypto/hash/htr/hashtree.go index da50f326c7..1ffb68b038 100644 --- a/crypto/hash/htr/hashtree.go +++ b/crypto/hash/htr/hashtree.go @@ -1,17 +1,47 @@ package htr import ( + "runtime" + "sync" + "github.com/prysmaticlabs/gohashtree" ) +const minSliceSizeToParallelize = 5000 + +func hashParallel(inputList [][32]byte, outputList [][32]byte, wg *sync.WaitGroup) { + defer wg.Done() + err := gohashtree.Hash(outputList, inputList) + if err != nil { + panic(err) + } +} + // VectorizedSha256 takes a list of roots and hashes them using CPU // specific vector instructions. Depending on host machine's specific // hardware configuration, using this routine can lead to a significant // performance improvement compared to the default method of hashing // lists. 
-func VectorizedSha256(inputList [][32]byte, outputList [][32]byte) { - err := gohashtree.Hash(outputList, inputList) +func VectorizedSha256(inputList [][32]byte) [][32]byte { + outputList := make([][32]byte, len(inputList)/2) + if len(inputList) < minSliceSizeToParallelize { + err := gohashtree.Hash(outputList, inputList) + if err != nil { + panic(err) + } + return outputList + } + n := runtime.GOMAXPROCS(0) - 1 + wg := sync.WaitGroup{} + wg.Add(n) + groupSize := len(inputList) / (2 * (n + 1)) + for j := 0; j < n; j++ { + go hashParallel(inputList[j*2*groupSize:(j+1)*2*groupSize], outputList[j*groupSize:], &wg) + } + err := gohashtree.Hash(outputList[n*groupSize:], inputList[n*2*groupSize:]) if err != nil { panic(err) } + wg.Wait() + return outputList } diff --git a/crypto/hash/htr/hashtree_test.go b/crypto/hash/htr/hashtree_test.go new file mode 100644 index 0000000000..3faa66f1de --- /dev/null +++ b/crypto/hash/htr/hashtree_test.go @@ -0,0 +1,27 @@ +package htr + +import ( + "sync" + "testing" + + "github.com/prysmaticlabs/prysm/v4/testing/require" +) + +func Test_VectorizedSha256(t *testing.T) { + largeSlice := make([][32]byte, 32*minSliceSizeToParallelize) + secondLargeSlice := make([][32]byte, 32*minSliceSizeToParallelize) + hash1 := make([][32]byte, 16*minSliceSizeToParallelize) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + tempHash := VectorizedSha256(largeSlice) + copy(hash1, tempHash) + }() + wg.Wait() + hash2 := VectorizedSha256(secondLargeSlice) + require.Equal(t, len(hash1), len(hash2)) + for i, r := range hash1 { + require.Equal(t, r, hash2[i]) + } +} diff --git a/encoding/ssz/merkleize.go b/encoding/ssz/merkleize.go index 1b22b159ef..af27f3106b 100644 --- a/encoding/ssz/merkleize.go +++ b/encoding/ssz/merkleize.go @@ -213,9 +213,7 @@ func MerkleizeVector(elements [][32]byte, length uint64) [32]byte { zerohash := trie.ZeroHashes[i] elements = append(elements, zerohash) } - outputLen := len(elements) / 2 - 
htr.VectorizedSha256(elements, elements) - elements = elements[:outputLen] + elements = htr.VectorizedSha256(elements) } return elements[0] }