Different parallel hashing (#12639)

* Paralellize hashing of large lists

* add unit test

* add file

* do not parallelize on low processor count

* revert minimal proc count

---------

Co-authored-by: Nishant Das <nishdas93@gmail.com>
This commit is contained in:
Potuz
2023-07-21 20:36:20 -04:00
committed by GitHub
parent 43378ae8d5
commit f97db3b738
9 changed files with 132 additions and 20 deletions

View File

@@ -34,6 +34,7 @@ go_library(
"//math:go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
"@com_github_pkg_errors//:go_default_library",
"@com_github_sirupsen_logrus//:go_default_library",
],
)

View File

@@ -3,12 +3,15 @@ package stateutil
import (
"bytes"
"encoding/binary"
"runtime"
"sync"
"github.com/pkg/errors"
fieldparams "github.com/prysmaticlabs/prysm/v4/config/fieldparams"
"github.com/prysmaticlabs/prysm/v4/crypto/hash/htr"
"github.com/prysmaticlabs/prysm/v4/encoding/ssz"
ethpb "github.com/prysmaticlabs/prysm/v4/proto/prysm/v1alpha1"
"github.com/sirupsen/logrus"
)
const (
@@ -51,6 +54,20 @@ func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) {
return res, nil
}
func hashValidatorHelper(validators []*ethpb.Validator, roots [][32]byte, j int, groupSize int, wg *sync.WaitGroup) {
defer wg.Done()
for i := 0; i < groupSize; i++ {
fRoots, err := ValidatorFieldRoots(validators[j*groupSize+i])
if err != nil {
logrus.WithError(err).Error("could not get validator field roots")
return
}
for k, root := range fRoots {
roots[(j*groupSize+i)*validatorFieldRoots+k] = root
}
}
}
// OptimizedValidatorRoots uses an optimized routine with gohashtree in order to
// derive a list of validator roots from a list of validator objects.
func OptimizedValidatorRoots(validators []*ethpb.Validator) ([][32]byte, error) {
@@ -58,14 +75,25 @@ func OptimizedValidatorRoots(validators []*ethpb.Validator) ([][32]byte, error)
if len(validators) == 0 {
return [][32]byte{}, nil
}
roots := make([][32]byte, 0, len(validators)*validatorFieldRoots)
for i := 0; i < len(validators); i++ {
wg := sync.WaitGroup{}
n := runtime.GOMAXPROCS(0)
rootsSize := len(validators) * validatorFieldRoots
groupSize := len(validators) / n
roots := make([][32]byte, rootsSize)
wg.Add(n - 1)
for j := 0; j < n-1; j++ {
go hashValidatorHelper(validators, roots, j, groupSize, &wg)
}
for i := (n - 1) * groupSize; i < len(validators); i++ {
fRoots, err := ValidatorFieldRoots(validators[i])
if err != nil {
return [][32]byte{}, errors.Wrap(err, "could not compute validators merkleization")
}
roots = append(roots, fRoots...)
for k, root := range fRoots {
roots[i*validatorFieldRoots+k] = root
}
}
wg.Wait()
// A validator's tree can represented with a depth of 3. As log2(8) = 3
// Using this property we can lay out all the individual fields of a
@@ -73,9 +101,7 @@ func OptimizedValidatorRoots(validators []*ethpb.Validator) ([][32]byte, error)
for i := 0; i < validatorTreeDepth; i++ {
// Overwrite input lists as we are hashing by level
// and only need the highest level to proceed.
outputLen := len(roots) / 2
htr.VectorizedSha256(roots, roots)
roots = roots[:outputLen]
roots = htr.VectorizedSha256(roots)
}
return roots, nil
}

View File

@@ -3,11 +3,13 @@ package stateutil
import (
"reflect"
"strings"
"sync"
"testing"
mathutil "github.com/prysmaticlabs/prysm/v4/math"
ethpb "github.com/prysmaticlabs/prysm/v4/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/v4/testing/assert"
"github.com/prysmaticlabs/prysm/v4/testing/require"
)
func TestValidatorConstants(t *testing.T) {
@@ -30,3 +32,28 @@ func TestValidatorConstants(t *testing.T) {
_, err := ValidatorRegistryRoot([]*ethpb.Validator{v})
assert.NoError(t, err)
}
func TestHashValidatorHelper(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(1)
v := &ethpb.Validator{}
valList := make([]*ethpb.Validator, 10*validatorFieldRoots)
for i := range valList {
valList[i] = v
}
roots := make([][32]byte, len(valList))
hashValidatorHelper(valList, roots, 2, 2, &wg)
for i := 0; i < 4*validatorFieldRoots; i++ {
require.Equal(t, [32]byte{}, roots[i])
}
emptyValRoots, err := ValidatorFieldRoots(v)
require.NoError(t, err)
for i := 4; i < 6; i++ {
for j := 0; j < validatorFieldRoots; j++ {
require.Equal(t, emptyValRoots[j], roots[i*validatorFieldRoots+j])
}
}
for i := 6 * validatorFieldRoots; i < 10*validatorFieldRoots; i++ {
require.Equal(t, [32]byte{}, roots[i])
}
}

View File

@@ -48,8 +48,7 @@ func merkleizePubkey(pubkey []byte) ([32]byte, error) {
if err != nil {
return [32]byte{}, err
}
outputChunk := make([][32]byte, 1)
htr.VectorizedSha256(chunks, outputChunk)
outputChunk := htr.VectorizedSha256(chunks)
return outputChunk[0], nil
}

View File

@@ -71,9 +71,7 @@ func ReturnTrieLayerVariable(elements [][32]byte, length uint64) [][]*[32]byte {
}
layers[i+1] = make([]*[32]byte, layerLen/2)
newElems := make([][32]byte, layerLen/2)
htr.VectorizedSha256(elements, newElems)
elements = newElems
elements = htr.VectorizedSha256(elements)
for j := range elements {
layers[i+1][j] = &elements[j]
}
@@ -295,9 +293,7 @@ func MerkleizeTrieLeaves(layers [][][32]byte, hashLayer [][32]byte) ([][][32]byt
if !math.IsPowerOf2(uint64(len(hashLayer))) {
return nil, nil, errors.Errorf("hash layer is a non power of 2: %d", len(hashLayer))
}
newLayer := make([][32]byte, len(hashLayer)/2)
htr.VectorizedSha256(hashLayer, newLayer)
hashLayer = newLayer
hashLayer = htr.VectorizedSha256(hashLayer)
layers[i] = hashLayer
i++
}

View File

@@ -1,4 +1,4 @@
load("@prysm//tools/go:def.bzl", "go_library")
load("@prysm//tools/go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
@@ -7,3 +7,11 @@ go_library(
visibility = ["//visibility:public"],
deps = ["@com_github_prysmaticlabs_gohashtree//:go_default_library"],
)
go_test(
name = "go_default_test",
size = "small",
srcs = ["hashtree_test.go"],
embed = [":go_default_library"],
deps = ["//testing/require:go_default_library"],
)

View File

@@ -1,17 +1,47 @@
package htr
import (
"runtime"
"sync"
"github.com/prysmaticlabs/gohashtree"
)
const minSliceSizeToParallelize = 5000
func hashParallel(inputList [][32]byte, outputList [][32]byte, wg *sync.WaitGroup) {
defer wg.Done()
err := gohashtree.Hash(outputList, inputList)
if err != nil {
panic(err)
}
}
// VectorizedSha256 takes a list of roots and hashes them using CPU
// specific vector instructions. Depending on host machine's specific
// hardware configuration, using this routine can lead to a significant
// performance improvement compared to the default method of hashing
// lists.
func VectorizedSha256(inputList [][32]byte, outputList [][32]byte) {
err := gohashtree.Hash(outputList, inputList)
func VectorizedSha256(inputList [][32]byte) [][32]byte {
outputList := make([][32]byte, len(inputList)/2)
if len(inputList) < minSliceSizeToParallelize {
err := gohashtree.Hash(outputList, inputList)
if err != nil {
panic(err)
}
return outputList
}
n := runtime.GOMAXPROCS(0) - 1
wg := sync.WaitGroup{}
wg.Add(n)
groupSize := len(inputList) / (2 * (n + 1))
for j := 0; j < n; j++ {
go hashParallel(inputList[j*2*groupSize:(j+1)*2*groupSize], outputList[j*groupSize:], &wg)
}
err := gohashtree.Hash(outputList[n*groupSize:], inputList[n*2*groupSize:])
if err != nil {
panic(err)
}
wg.Wait()
return outputList
}

View File

@@ -0,0 +1,27 @@
package htr
import (
"sync"
"testing"
"github.com/prysmaticlabs/prysm/v4/testing/require"
)
func Test_VectorizedSha256(t *testing.T) {
largeSlice := make([][32]byte, 32*minSliceSizeToParallelize)
secondLargeSlice := make([][32]byte, 32*minSliceSizeToParallelize)
hash1 := make([][32]byte, 16*minSliceSizeToParallelize)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
tempHash := VectorizedSha256(largeSlice)
copy(hash1, tempHash)
}()
wg.Wait()
hash2 := VectorizedSha256(secondLargeSlice)
require.Equal(t, len(hash1), len(hash2))
for i, r := range hash1 {
require.Equal(t, r, hash2[i])
}
}

View File

@@ -213,9 +213,7 @@ func MerkleizeVector(elements [][32]byte, length uint64) [32]byte {
zerohash := trie.ZeroHashes[i]
elements = append(elements, zerohash)
}
outputLen := len(elements) / 2
htr.VectorizedSha256(elements, elements)
elements = elements[:outputLen]
elements = htr.VectorizedSha256(elements)
}
return elements[0]
}