From 339540274b2436f941459703778041f626ff9afe Mon Sep 17 00:00:00 2001 From: Nishant Das Date: Mon, 28 Feb 2022 21:56:12 +0800 Subject: [PATCH] Integration of Vectorized Sha256 In Prysm (#10166) * add changes * fix for vectorize * fix bug * add new bench * use new algorithms * add latest updates * save progress * hack even more * add more changes * change library * go mod * fix deps * fix dumb bug * add flag and remove redundant code * clean up better * remove those ones * clean up benches * clean up benches * cleanup * gaz * revert change * potuz's review * potuz's review * potuz's review * gaz * potuz's review * remove cyclical import * revert ide changes * potuz's review * return --- .../state/state-native/v1/state_trie_test.go | 18 +++ beacon-chain/state/stateutil/BUILD.bazel | 6 + .../state/stateutil/field_root_validator.go | 65 ++++++++- .../stateutil/field_root_validator_test.go | 32 +++++ beacon-chain/state/stateutil/trie_helpers.go | 76 ++++++---- .../state/stateutil/trie_helpers_test.go | 130 ++++++++++++++++-- .../state/stateutil/validator_root.go | 31 ++++- beacon-chain/state/v1/state_trie_test.go | 18 +++ beacon-chain/sync/BUILD.bazel | 4 +- beacon-chain/sync/pending_blocks_queue.go | 4 +- beacon-chain/sync/rpc_metadata_test.go | 6 +- config/features/config.go | 7 +- config/features/flags.go | 6 + crypto/hash/htr/BUILD.bazel | 9 ++ crypto/hash/htr/hashtree.go | 17 +++ deps.bzl | 7 + encoding/ssz/BUILD.bazel | 6 +- encoding/ssz/equality/BUILD.bazel | 22 +++ encoding/ssz/{ => equality}/deep_equal.go | 2 +- .../ssz/{ => equality}/deep_equal_test.go | 52 +++---- encoding/ssz/helpers.go | 7 + encoding/ssz/merkleize.go | 38 +++++ go.mod | 1 + go.sum | 2 + .../aggregation/attestations/BUILD.bazel | 2 +- .../attestations/attestations_test.go | 4 +- testing/assertions/BUILD.bazel | 2 +- testing/assertions/assertions.go | 6 +- tools/pcli/BUILD.bazel | 2 +- tools/pcli/main.go | 4 +- 30 files changed, 491 insertions(+), 95 deletions(-) create mode 100644 beacon-chain/state/stateutil/field_root_validator_test.go create mode 100644 crypto/hash/htr/BUILD.bazel create mode 100644 crypto/hash/htr/hashtree.go create mode 100644 encoding/ssz/equality/BUILD.bazel rename encoding/ssz/{ => equality}/deep_equal.go (99%) rename encoding/ssz/{ => equality}/deep_equal_test.go (54%) diff --git a/beacon-chain/state/state-native/v1/state_trie_test.go b/beacon-chain/state/state-native/v1/state_trie_test.go index e24122ba91..361d62162f 100644 --- a/beacon-chain/state/state-native/v1/state_trie_test.go +++ b/beacon-chain/state/state-native/v1/state_trie_test.go @@ -258,3 +258,21 @@ func TestBeaconState_AppendValidator_DoesntMutateCopy(t *testing.T) { _, ok := st1.ValidatorIndexByPubkey(bytesutil.ToBytes48(val.PublicKey)) assert.Equal(t, false, ok, "Expected no validator index to be present in st1 for the newly inserted pubkey") } + +func BenchmarkBeaconState(b *testing.B) { + testState, _ := util.DeterministicGenesisState(b, 16000) + pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe()) + require.NoError(b, err) + + b.Run("Vectorized SHA256", func(b *testing.B) { + st, err := v1.InitializeFromProtoUnsafe(pbState) + require.NoError(b, err) + _, err = st.HashTreeRoot(context.Background()) + assert.NoError(b, err) + }) + + b.Run("Current SHA256", func(b *testing.B) { + _, err := pbState.HashTreeRoot() + require.NoError(b, err) + }) +} diff --git a/beacon-chain/state/stateutil/BUILD.bazel b/beacon-chain/state/stateutil/BUILD.bazel index 2afc80e077..9e3300ca4a 100644 --- a/beacon-chain/state/stateutil/BUILD.bazel +++ b/beacon-chain/state/stateutil/BUILD.bazel @@ -32,10 +32,12 @@ go_library( ], deps = [ "//beacon-chain/core/transition/stateutils:go_default_library", + "//config/features:go_default_library", "//config/fieldparams:go_default_library", "//config/params:go_default_library", "//container/trie:go_default_library", "//crypto/hash:go_default_library", + "//crypto/hash/htr:go_default_library", "//encoding/bytesutil:go_default_library", "//encoding/ssz:go_default_library", "//math:go_default_library", @@ -51,6 +53,7 @@ go_test( srcs = [ "benchmark_test.go", "field_root_test.go", + "field_root_validator_test.go", "reference_bench_test.go", "state_root_test.go", "trie_helpers_test.go", @@ -58,11 +61,14 @@ go_test( ], embed = [":go_default_library"], deps = [ + "//beacon-chain/state:go_default_library", + "//config/features:go_default_library", "//config/fieldparams:go_default_library", "//config/params:go_default_library", "//crypto/hash:go_default_library", "//encoding/bytesutil:go_default_library", "//encoding/ssz:go_default_library", + "//math:go_default_library", "//proto/prysm/v1alpha1:go_default_library", "//runtime/interop:go_default_library", "//testing/assert:go_default_library", diff --git a/beacon-chain/state/stateutil/field_root_validator.go b/beacon-chain/state/stateutil/field_root_validator.go index 72d43be234..c7c205fc1b 100644 --- a/beacon-chain/state/stateutil/field_root_validator.go +++ b/beacon-chain/state/stateutil/field_root_validator.go @@ -5,12 +5,25 @@ import ( "encoding/binary" "github.com/pkg/errors" + "github.com/prysmaticlabs/prysm/config/features" fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams" "github.com/prysmaticlabs/prysm/crypto/hash" + "github.com/prysmaticlabs/prysm/crypto/hash/htr" "github.com/prysmaticlabs/prysm/encoding/ssz" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" ) +const ( + // number of field roots for the validator object. + validatorFieldRoots = 8 + + // Depth of tree representation of an individual + // validator. + // NumOfRoots = 2 ^ (TreeDepth) + // 8 = 2 ^ 3 + validatorTreeDepth = 3 +) + // ValidatorRegistryRoot computes the HashTreeRoot Merkleization of // a list of validator structs according to the Ethereum // Simple Serialize specification. @@ -19,14 +32,20 @@ func ValidatorRegistryRoot(vals []*ethpb.Validator) ([32]byte, error) { } func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) { - roots := make([][32]byte, len(validators)) hasher := hash.CustomSHA256Hasher() - for i := 0; i < len(validators); i++ { - val, err := validatorRoot(hasher, validators[i]) + + var err error + var roots [][32]byte + if features.Get().EnableVectorizedHTR { + roots, err = optimizedValidatorRoots(validators) if err != nil { - return [32]byte{}, errors.Wrap(err, "could not compute validators merkleization") + return [32]byte{}, err + } + } else { + roots, err = validatorRoots(hasher, validators) + if err != nil { + return [32]byte{}, err } - roots[i] = val } validatorsRootsRoot, err := ssz.BitwiseMerkleizeArrays(hasher, roots, uint64(len(roots)), fieldparams.ValidatorRegistryLimit) @@ -45,6 +64,42 @@ func validatorRegistryRoot(validators []*ethpb.Validator) ([32]byte, error) { return res, nil } +func validatorRoots(hasher func([]byte) [32]byte, validators []*ethpb.Validator) ([][32]byte, error) { + roots := make([][32]byte, len(validators)) + for i := 0; i < len(validators); i++ { + val, err := validatorRoot(hasher, validators[i]) + if err != nil { + return [][32]byte{}, errors.Wrap(err, "could not compute validators merkleization") + } + roots[i] = val + } + return roots, nil +} + +func optimizedValidatorRoots(validators []*ethpb.Validator) ([][32]byte, error) { + roots := make([][32]byte, 0, len(validators)*validatorFieldRoots) + hasher := hash.CustomSHA256Hasher() + for i := 0; i < len(validators); i++ { + fRoots, err := ValidatorFieldRoots(hasher, validators[i]) + if err != nil { + return [][32]byte{}, errors.Wrap(err, "could not compute validators merkleization") + } + roots = append(roots, fRoots...) + } + + // A validator's tree can represented with a depth of 3. As log2(8) = 3 + // Using this property we can lay out all the individual fields of a + // validator and hash them in single level using our vectorized routine. + for i := 0; i < validatorTreeDepth; i++ { + // Overwrite input lists as we are hashing by level + // and only need the highest level to proceed. + outputLen := len(roots) / 2 + htr.VectorizedSha256(roots, roots) + roots = roots[:outputLen] + } + return roots, nil +} + func validatorRoot(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) { if validator == nil { return [32]byte{}, errors.New("nil validator") diff --git a/beacon-chain/state/stateutil/field_root_validator_test.go b/beacon-chain/state/stateutil/field_root_validator_test.go new file mode 100644 index 0000000000..43ff89e8c3 --- /dev/null +++ b/beacon-chain/state/stateutil/field_root_validator_test.go @@ -0,0 +1,32 @@ +package stateutil + +import ( + "reflect" + "strings" + "testing" + + mathutil "github.com/prysmaticlabs/prysm/math" + ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/testing/assert" +) + +func TestValidatorConstants(t *testing.T) { + v := ðpb.Validator{} + refV := reflect.ValueOf(v).Elem() + numFields := refV.NumField() + numOfValFields := 0 + + for i := 0; i < numFields; i++ { + if strings.Contains(refV.Type().Field(i).Name, "state") || + strings.Contains(refV.Type().Field(i).Name, "sizeCache") || + strings.Contains(refV.Type().Field(i).Name, "unknownFields") { + continue + } + numOfValFields++ + } + assert.Equal(t, validatorFieldRoots, numOfValFields) + assert.Equal(t, uint64(validatorFieldRoots), mathutil.PowerOf2(validatorTreeDepth)) + + _, err := ValidatorRegistryRoot([]*ethpb.Validator{v}) + assert.NoError(t, err) +} diff --git a/beacon-chain/state/stateutil/trie_helpers.go b/beacon-chain/state/stateutil/trie_helpers.go index e0c0cdf5eb..7ff845d192 100644 --- a/beacon-chain/state/stateutil/trie_helpers.go +++ b/beacon-chain/state/stateutil/trie_helpers.go @@ -5,8 +5,10 @@ import ( "encoding/binary" "github.com/pkg/errors" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/container/trie" "github.com/prysmaticlabs/prysm/crypto/hash" + "github.com/prysmaticlabs/prysm/crypto/hash/htr" "github.com/prysmaticlabs/prysm/encoding/ssz" "github.com/prysmaticlabs/prysm/math" ) @@ -61,25 +63,43 @@ func ReturnTrieLayerVariable(elements [][32]byte, length uint64) [][]*[32]byte { layers[0] = transformedLeaves buffer := bytes.NewBuffer([]byte{}) buffer.Grow(64) + for i := 0; i < int(depth); i++ { - oddNodeLength := len(layers[i])%2 == 1 - if oddNodeLength { - zerohash := trie.ZeroHashes[i] - layers[i] = append(layers[i], &zerohash) + layerLen := len(layers[i]) + oddNodeLength := layerLen%2 == 1 + if features.Get().EnableVectorizedHTR { + if oddNodeLength { + zerohash := trie.ZeroHashes[i] + elements = append(elements, zerohash) + layerLen++ + } + + layers[i+1] = make([]*[32]byte, layerLen/2) + newElems := make([][32]byte, layerLen/2) + htr.VectorizedSha256(elements, newElems) + elements = newElems + for j := range elements { + layers[i+1][j] = &elements[j] + } + } else { + if oddNodeLength { + zerohash := trie.ZeroHashes[i] + layers[i] = append(layers[i], &zerohash) + } + updatedValues := make([]*[32]byte, 0, len(layers[i])/2) + for j := 0; j < len(layers[i]); j += 2 { + buffer.Write(layers[i][j][:]) + buffer.Write(layers[i][j+1][:]) + concat := hasher(buffer.Bytes()) + updatedValues = append(updatedValues, &concat) + buffer.Reset() + } + // remove zerohash node from tree + if oddNodeLength { + layers[i] = layers[i][:len(layers[i])-1] + } + layers[i+1] = updatedValues } - updatedValues := make([]*[32]byte, 0, len(layers[i])/2) - for j := 0; j < len(layers[i]); j += 2 { - buffer.Write(layers[i][j][:]) - buffer.Write(layers[i][j+1][:]) - concat := hasher(buffer.Bytes()) - updatedValues = append(updatedValues, &concat) - buffer.Reset() - } - // remove zerohash node from tree - if oddNodeLength { - layers[i] = layers[i][:len(layers[i])-1] - } - layers[i+1] = updatedValues } return layers } @@ -277,18 +297,24 @@ func MerkleizeTrieLeaves(layers [][][32]byte, hashLayer [][32]byte, chunkBuffer := bytes.NewBuffer([]byte{}) chunkBuffer.Grow(64) for len(hashLayer) > 1 && i < len(layers) { - layer := make([][32]byte, len(hashLayer)/2) if !math.IsPowerOf2(uint64(len(hashLayer))) { return nil, nil, errors.Errorf("hash layer is a non power of 2: %d", len(hashLayer)) } - for j := 0; j < len(hashLayer); j += 2 { - chunkBuffer.Write(hashLayer[j][:]) - chunkBuffer.Write(hashLayer[j+1][:]) - hashedChunk := hasher(chunkBuffer.Bytes()) - layer[j/2] = hashedChunk - chunkBuffer.Reset() + if features.Get().EnableVectorizedHTR { + newLayer := make([][32]byte, len(hashLayer)/2) + htr.VectorizedSha256(hashLayer, newLayer) + hashLayer = newLayer + } else { + layer := make([][32]byte, len(hashLayer)/2) + for j := 0; j < len(hashLayer); j += 2 { + chunkBuffer.Write(hashLayer[j][:]) + chunkBuffer.Write(hashLayer[j+1][:]) + hashedChunk := hasher(chunkBuffer.Bytes()) + layer[j/2] = hashedChunk + chunkBuffer.Reset() + } + hashLayer = layer } - hashLayer = layer layers[i] = hashLayer i++ } diff --git a/beacon-chain/state/stateutil/trie_helpers_test.go b/beacon-chain/state/stateutil/trie_helpers_test.go index 24e86ace6a..8b36053345 100644 --- a/beacon-chain/state/stateutil/trie_helpers_test.go +++ b/beacon-chain/state/stateutil/trie_helpers_test.go @@ -4,7 +4,9 @@ import ( "testing" types "github.com/prysmaticlabs/eth2-types" + "github.com/prysmaticlabs/prysm/beacon-chain/state" "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/config/params" "github.com/prysmaticlabs/prysm/crypto/hash" "github.com/prysmaticlabs/prysm/encoding/bytesutil" @@ -18,15 +20,56 @@ func TestReturnTrieLayer_OK(t *testing.T) { newState, _ := util.DeterministicGenesisState(t, 32) root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot)) require.NoError(t, err) - blockRts := newState.BlockRoots() - roots := make([][32]byte, 0, len(blockRts)) - for _, rt := range blockRts { - roots = append(roots, bytesutil.ToBytes32(rt)) - } + roots := retrieveBlockRoots(newState) layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots))) assert.NoError(t, err) newRoot := *layers[len(layers)-1][0] assert.Equal(t, root, newRoot) + + flags := &features.Flags{} + flags.EnableVectorizedHTR = true + reset := features.InitWithReset(flags) + defer reset() + + layers, err = stateutil.ReturnTrieLayer(roots, uint64(len(roots))) + assert.NoError(t, err) + lastRoot := *layers[len(layers)-1][0] + assert.Equal(t, root, lastRoot) +} + +func BenchmarkReturnTrieLayer_NormalAlgorithm(b *testing.B) { + newState, _ := util.DeterministicGenesisState(b, 32) + root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot)) + require.NoError(b, err) + roots := retrieveBlockRoots(newState) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots))) + assert.NoError(b, err) + newRoot := *layers[len(layers)-1][0] + assert.Equal(b, root, newRoot) + } +} + +func BenchmarkReturnTrieLayer_VectorizedAlgorithm(b *testing.B) { + flags := &features.Flags{} + flags.EnableVectorizedHTR = true + reset := features.InitWithReset(flags) + defer reset() + + newState, _ := util.DeterministicGenesisState(b, 32) + root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot)) + require.NoError(b, err) + roots := retrieveBlockRoots(newState) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots))) + assert.NoError(b, err) + newRoot := *layers[len(layers)-1][0] + assert.Equal(b, root, newRoot) + } } func TestReturnTrieLayerVariable_OK(t *testing.T) { @@ -46,15 +89,73 @@ func TestReturnTrieLayerVariable_OK(t *testing.T) { newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators))) require.NoError(t, err) assert.Equal(t, root, newRoot) + + flags := &features.Flags{} + flags.EnableVectorizedHTR = true + reset := features.InitWithReset(flags) + defer reset() + + layers = stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit) + lastRoot := *layers[len(layers)-1][0] + lastRoot, err = stateutil.AddInMixin(lastRoot, uint64(len(validators))) + require.NoError(t, err) + assert.Equal(t, root, lastRoot) + +} + +func BenchmarkReturnTrieLayerVariable_NormalAlgorithm(b *testing.B) { + newState, _ := util.DeterministicGenesisState(b, 16000) + root, err := stateutil.ValidatorRegistryRoot(newState.Validators()) + require.NoError(b, err) + hasher := hash.CustomSHA256Hasher() + validators := newState.Validators() + roots := make([][32]byte, 0, len(validators)) + for _, val := range validators { + rt, err := stateutil.ValidatorRootWithHasher(hasher, val) + require.NoError(b, err) + roots = append(roots, rt) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + layers := stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit) + newRoot := *layers[len(layers)-1][0] + newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators))) + require.NoError(b, err) + assert.Equal(b, root, newRoot) + } +} + +func BenchmarkReturnTrieLayerVariable_VectorizedAlgorithm(b *testing.B) { + flags := &features.Flags{} + flags.EnableVectorizedHTR = true + reset := features.InitWithReset(flags) + defer reset() + + newState, _ := util.DeterministicGenesisState(b, 16000) + root, err := stateutil.ValidatorRegistryRoot(newState.Validators()) + require.NoError(b, err) + hasher := hash.CustomSHA256Hasher() + validators := newState.Validators() + roots := make([][32]byte, 0, len(validators)) + for _, val := range validators { + rt, err := stateutil.ValidatorRootWithHasher(hasher, val) + require.NoError(b, err) + roots = append(roots, rt) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + layers := stateutil.ReturnTrieLayerVariable(roots, params.BeaconConfig().ValidatorRegistryLimit) + newRoot := *layers[len(layers)-1][0] + newRoot, err = stateutil.AddInMixin(newRoot, uint64(len(validators))) + require.NoError(b, err) + assert.Equal(b, root, newRoot) + } } func TestRecomputeFromLayer_FixedSizedArray(t *testing.T) { newState, _ := util.DeterministicGenesisState(t, 32) - blockRts := newState.BlockRoots() - roots := make([][32]byte, 0, len(blockRts)) - for _, rt := range blockRts { - roots = append(roots, bytesutil.ToBytes32(rt)) - } + roots := retrieveBlockRoots(newState) + layers, err := stateutil.ReturnTrieLayer(roots, uint64(len(roots))) require.NoError(t, err) @@ -120,3 +221,12 @@ func TestMerkleizeTrieLeaves_BadHashLayer(t *testing.T) { }) assert.ErrorContains(t, "hash layer is a non power of 2", err) } + +func retrieveBlockRoots(b state.BeaconState) [][32]byte { + blockRts := b.BlockRoots() + roots := make([][32]byte, 0, len(blockRts)) + for _, rt := range blockRts { + roots = append(roots, bytesutil.ToBytes32(rt)) + } + return roots +} diff --git a/beacon-chain/state/stateutil/validator_root.go b/beacon-chain/state/stateutil/validator_root.go index 47c823b4d9..8705b23627 100644 --- a/beacon-chain/state/stateutil/validator_root.go +++ b/beacon-chain/state/stateutil/validator_root.go @@ -4,8 +4,10 @@ import ( "encoding/binary" "github.com/pkg/errors" + "github.com/prysmaticlabs/prysm/config/features" fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams" "github.com/prysmaticlabs/prysm/crypto/hash" + "github.com/prysmaticlabs/prysm/crypto/hash/htr" "github.com/prysmaticlabs/prysm/encoding/bytesutil" "github.com/prysmaticlabs/prysm/encoding/ssz" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" @@ -14,6 +16,16 @@ import ( // ValidatorRootWithHasher describes a method from which the hash tree root // of a validator is returned. func ValidatorRootWithHasher(hasher ssz.HashFn, validator *ethpb.Validator) ([32]byte, error) { + fieldRoots, err := ValidatorFieldRoots(hasher, validator) + if err != nil { + return [32]byte{}, err + } + return ssz.BitwiseMerkleizeArrays(hasher, fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots))) +} + +// ValidatorFieldRoots describes a method from which the hash tree root +// of a validator is returned. +func ValidatorFieldRoots(hasher ssz.HashFn, validator *ethpb.Validator) ([][32]byte, error) { var fieldRoots [][32]byte if validator != nil { pubkey := bytesutil.ToBytes48(validator.PublicKey) @@ -40,18 +52,25 @@ func ValidatorRootWithHasher(hasher ssz.HashFn, validator *ethpb.Validator) ([32 binary.LittleEndian.PutUint64(withdrawalBuf[:8], uint64(validator.WithdrawableEpoch)) // Public key. - pubKeyChunks, err := ssz.Pack([][]byte{pubkey[:]}) + pubKeyChunks, err := ssz.PackByChunk([][]byte{pubkey[:]}) if err != nil { - return [32]byte{}, err + return [][32]byte{}, err } - pubKeyRoot, err := ssz.BitwiseMerkleize(hasher, pubKeyChunks, uint64(len(pubKeyChunks)), uint64(len(pubKeyChunks))) - if err != nil { - return [32]byte{}, err + var pubKeyRoot [32]byte + if features.Get().EnableVectorizedHTR { + outputChunk := make([][32]byte, 1) + htr.VectorizedSha256(pubKeyChunks, outputChunk) + pubKeyRoot = outputChunk[0] + } else { + pubKeyRoot, err = ssz.BitwiseMerkleizeArrays(hasher, pubKeyChunks, uint64(len(pubKeyChunks)), uint64(len(pubKeyChunks))) + if err != nil { + return [][32]byte{}, err + } } fieldRoots = [][32]byte{pubKeyRoot, withdrawCreds, effectiveBalanceBuf, slashBuf, activationEligibilityBuf, activationBuf, exitBuf, withdrawalBuf} } - return ssz.BitwiseMerkleizeArrays(hasher, fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots))) + return fieldRoots, nil } // Uint64ListRootWithRegistryLimit computes the HashTreeRoot Merkleization of diff --git a/beacon-chain/state/v1/state_trie_test.go b/beacon-chain/state/v1/state_trie_test.go index 73c1beecf7..6455cc077c 100644 --- a/beacon-chain/state/v1/state_trie_test.go +++ b/beacon-chain/state/v1/state_trie_test.go @@ -177,6 +177,24 @@ func TestBeaconState_HashTreeRoot(t *testing.T) { } } +func BenchmarkBeaconState(b *testing.B) { + testState, _ := util.DeterministicGenesisState(b, 16000) + pbState, err := v1.ProtobufBeaconState(testState.InnerStateUnsafe()) + require.NoError(b, err) + + b.Run("Vectorized SHA256", func(b *testing.B) { + st, err := v1.InitializeFromProtoUnsafe(pbState) + require.NoError(b, err) + _, err = st.HashTreeRoot(context.Background()) + assert.NoError(b, err) + }) + + b.Run("Current SHA256", func(b *testing.B) { + _, err := pbState.HashTreeRoot() + require.NoError(b, err) + }) +} + func TestBeaconState_HashTreeRoot_FieldTrie(t *testing.T) { testState, _ := util.DeterministicGenesisState(t, 64) diff --git a/beacon-chain/sync/BUILD.bazel b/beacon-chain/sync/BUILD.bazel index 88a9a0ff56..2e42694ce7 100644 --- a/beacon-chain/sync/BUILD.bazel +++ b/beacon-chain/sync/BUILD.bazel @@ -86,7 +86,7 @@ go_library( "//crypto/bls:go_default_library", "//crypto/rand:go_default_library", "//encoding/bytesutil:go_default_library", - "//encoding/ssz:go_default_library", + "//encoding/ssz/equality:go_default_library", "//monitoring/tracing:go_default_library", "//network/forks:go_default_library", "//proto/prysm/v1alpha1:go_default_library", @@ -195,7 +195,7 @@ go_test( "//crypto/bls:go_default_library", "//crypto/rand:go_default_library", "//encoding/bytesutil:go_default_library", - "//encoding/ssz:go_default_library", + "//encoding/ssz/equality:go_default_library", "//network/forks:go_default_library", "//proto/prysm/v1alpha1:go_default_library", "//proto/prysm/v1alpha1/attestation:go_default_library", diff --git a/beacon-chain/sync/pending_blocks_queue.go b/beacon-chain/sync/pending_blocks_queue.go index 0daf892d0b..a0bf2ab1e6 100644 --- a/beacon-chain/sync/pending_blocks_queue.go +++ b/beacon-chain/sync/pending_blocks_queue.go @@ -15,7 +15,7 @@ import ( "github.com/prysmaticlabs/prysm/config/params" "github.com/prysmaticlabs/prysm/crypto/rand" "github.com/prysmaticlabs/prysm/encoding/bytesutil" - "github.com/prysmaticlabs/prysm/encoding/ssz" + "github.com/prysmaticlabs/prysm/encoding/ssz/equality" "github.com/prysmaticlabs/prysm/monitoring/tracing" "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/block" "github.com/prysmaticlabs/prysm/time/slots" @@ -328,7 +328,7 @@ func (s *Service) deleteBlockFromPendingQueue(slot types.Slot, b block.SignedBea newBlks := make([]block.SignedBeaconBlock, 0, len(blks)) for _, blk := range blks { - if ssz.DeepEqual(blk.Proto(), b.Proto()) { + if equality.DeepEqual(blk.Proto(), b.Proto()) { continue } newBlks = append(newBlks, blk) diff --git a/beacon-chain/sync/rpc_metadata_test.go b/beacon-chain/sync/rpc_metadata_test.go index 105a5016e4..5eea65c915 100644 --- a/beacon-chain/sync/rpc_metadata_test.go +++ b/beacon-chain/sync/rpc_metadata_test.go @@ -17,7 +17,7 @@ import ( "github.com/prysmaticlabs/prysm/beacon-chain/p2p" p2ptest "github.com/prysmaticlabs/prysm/beacon-chain/p2p/testing" "github.com/prysmaticlabs/prysm/config/params" - "github.com/prysmaticlabs/prysm/encoding/ssz" + "github.com/prysmaticlabs/prysm/encoding/ssz/equality" pb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/metadata" "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper" @@ -125,7 +125,7 @@ func TestMetadataRPCHandler_SendsMetadata(t *testing.T) { metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID()) assert.NoError(t, err) - if !ssz.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) { + if !equality.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) { t.Fatalf("MetadataV0 unequal, received %v but wanted %v", metadata, p2.LocalMetadata) } @@ -213,7 +213,7 @@ func TestMetadataRPCHandler_SendsMetadataAltair(t *testing.T) { metadata, err := r.sendMetaDataRequest(context.Background(), p2.BHost.ID()) assert.NoError(t, err) - if !ssz.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) { + if !equality.DeepEqual(metadata.InnerObject(), p2.LocalMetadata.InnerObject()) { t.Fatalf("MetadataV1 unequal, received %v but wanted %v", metadata, p2.LocalMetadata) } diff --git a/config/features/config.go b/config/features/config.go index 6949dee234..4147c55f56 100644 --- a/config/features/config.go +++ b/config/features/config.go @@ -74,7 +74,8 @@ type Flags struct { CorrectlyInsertOrphanedAtts bool CorrectlyPruneCanonicalAtts bool - EnableNativeState bool // EnableNativeState defines whether the beacon state will be represented as a pure Go struct or a Go struct that wraps a proto struct. + EnableNativeState bool // EnableNativeState defines whether the beacon state will be represented as a pure Go struct or a Go struct that wraps a proto struct. + EnableVectorizedHTR bool // EnableVectorizedHTR specifies whether the beacon state will use the optimized sha256 routines. // KeystoreImportDebounceInterval specifies the time duration the validator waits to reload new keys if they have // changed on disk. This feature is for advanced use cases only. @@ -222,6 +223,10 @@ func ConfigureBeaconChain(ctx *cli.Context) { logEnabled(enableNativeState) cfg.EnableNativeState = true } + if ctx.Bool(enableVecHTR.Name) { + logEnabled(enableVecHTR) + cfg.EnableVectorizedHTR = true + } Init(cfg) } diff --git a/config/features/flags.go b/config/features/flags.go index 22f6bf4dbb..d487f279a0 100644 --- a/config/features/flags.go +++ b/config/features/flags.go @@ -134,11 +134,16 @@ var ( Name: "enable-native-state", Usage: "Enables representing the beacon state as a pure Go struct.", } + enableVecHTR = &cli.BoolFlag{ + Name: "enable-vectorized-htr", + Usage: "Enables new go sha256 library which utilizes optimized routines for merkle trees", + } ) // devModeFlags holds list of flags that are set when development mode is on. var devModeFlags = []cli.Flag{ enablePeerScorer, + enableVecHTR, } // ValidatorFlags contains a list of all the feature flags that apply to the validator client. @@ -183,6 +188,7 @@ var BeaconChainFlags = append(deprecatedFlags, []cli.Flag{ disableBatchGossipVerification, disableBalanceTrieComputation, enableNativeState, + enableVecHTR, }...) // E2EBeaconChainFlags contains a list of the beacon chain feature flags to be tested in E2E. diff --git a/crypto/hash/htr/BUILD.bazel b/crypto/hash/htr/BUILD.bazel new file mode 100644 index 0000000000..c8b78df350 --- /dev/null +++ b/crypto/hash/htr/BUILD.bazel @@ -0,0 +1,9 @@ +load("@prysm//tools/go:def.bzl", "go_library") + +go_library( + name = "go_default_library", + srcs = ["hashtree.go"], + importpath = "github.com/prysmaticlabs/prysm/crypto/hash/htr", + visibility = ["//visibility:public"], + deps = ["@com_github_prysmaticlabs_gohashtree//:go_default_library"], +) diff --git a/crypto/hash/htr/hashtree.go b/crypto/hash/htr/hashtree.go new file mode 100644 index 0000000000..da50f326c7 --- /dev/null +++ b/crypto/hash/htr/hashtree.go @@ -0,0 +1,17 @@ +package htr + +import ( + "github.com/prysmaticlabs/gohashtree" +) + +// VectorizedSha256 takes a list of roots and hashes them using CPU +// specific vector instructions. Depending on host machine's specific +// hardware configuration, using this routine can lead to a significant +// performance improvement compared to the default method of hashing +// lists. +func VectorizedSha256(inputList [][32]byte, outputList [][32]byte) { + err := gohashtree.Hash(outputList, inputList) + if err != nil { + panic(err) + } +} diff --git a/deps.bzl b/deps.bzl index d8774dbf45..9c7b7552cc 100644 --- a/deps.bzl +++ b/deps.bzl @@ -2982,6 +2982,13 @@ def prysm_deps(): sum = "h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw=", version = "v0.0.0-20210809151128-385d8c5e3fb7", ) + go_repository( + name = "com_github_prysmaticlabs_gohashtree", + importpath = "github.com/prysmaticlabs/gohashtree", + sum = "h1:qVXMFNUGWbPpP7986YPsmL84A54oY7b/N+SOIBS3i5w=", + version = "v0.0.0-20220208111633-0606f58df32f", + ) + go_repository( name = "com_github_prysmaticlabs_prombbolt", importpath = "github.com/prysmaticlabs/prombbolt", diff --git a/encoding/ssz/BUILD.bazel b/encoding/ssz/BUILD.bazel index 78378fd63c..fe8fe8efdb 100644 --- a/encoding/ssz/BUILD.bazel +++ b/encoding/ssz/BUILD.bazel @@ -3,7 +3,6 @@ load("@prysm//tools/go:def.bzl", "go_library", "go_test") go_library( name = "go_default_library", srcs = [ - "deep_equal.go", "hashers.go", "helpers.go", "htrutils.go", @@ -12,16 +11,16 @@ go_library( importpath = "github.com/prysmaticlabs/prysm/encoding/ssz", visibility = ["//visibility:public"], deps = [ + "//config/features:go_default_library", "//config/fieldparams:go_default_library", "//container/trie:go_default_library", "//crypto/hash:go_default_library", + "//crypto/hash/htr:go_default_library", "//encoding/bytesutil:go_default_library", "//proto/prysm/v1alpha1:go_default_library", "@com_github_minio_sha256_simd//:go_default_library", "@com_github_pkg_errors//:go_default_library", - "@com_github_prysmaticlabs_eth2_types//:go_default_library", "@com_github_prysmaticlabs_go_bitfield//:go_default_library", - "@org_golang_google_protobuf//proto:go_default_library", ], ) @@ -29,7 +28,6 @@ go_test( name = "go_default_test", size = "small", srcs = [ - "deep_equal_test.go", "hashers_test.go", "helpers_test.go", "htrutils_test.go", diff --git a/encoding/ssz/equality/BUILD.bazel b/encoding/ssz/equality/BUILD.bazel new file mode 100644 index 0000000000..52b39fb4fa --- /dev/null +++ b/encoding/ssz/equality/BUILD.bazel @@ -0,0 +1,22 @@ +load("@prysm//tools/go:def.bzl", "go_library", "go_test") + +go_library( + name = "go_default_library", + srcs = ["deep_equal.go"], + importpath = "github.com/prysmaticlabs/prysm/encoding/ssz/equality", + visibility = ["//visibility:public"], + deps = [ + "@com_github_prysmaticlabs_eth2_types//:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + ], +) + +go_test( + name = "go_default_test", + srcs = ["deep_equal_test.go"], + deps = [ + ":go_default_library", + "//proto/prysm/v1alpha1:go_default_library", + "//testing/assert:go_default_library", + ], +) diff --git a/encoding/ssz/deep_equal.go b/encoding/ssz/equality/deep_equal.go similarity index 99% rename from encoding/ssz/deep_equal.go rename to encoding/ssz/equality/deep_equal.go index aa47acc9f7..b779969d19 100644 --- a/encoding/ssz/deep_equal.go +++ b/encoding/ssz/equality/deep_equal.go @@ -1,4 +1,4 @@ -package ssz +package equality import ( "reflect" diff --git a/encoding/ssz/deep_equal_test.go b/encoding/ssz/equality/deep_equal_test.go similarity index 54% rename from encoding/ssz/deep_equal_test.go rename to encoding/ssz/equality/deep_equal_test.go index c4699f7c76..ef2c9e2e3e 100644 --- a/encoding/ssz/deep_equal_test.go +++ b/encoding/ssz/equality/deep_equal_test.go @@ -1,34 +1,34 @@ -package ssz_test +package equality_test import ( "testing" - "github.com/prysmaticlabs/prysm/encoding/ssz" + "github.com/prysmaticlabs/prysm/encoding/ssz/equality" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" "github.com/prysmaticlabs/prysm/testing/assert" ) func TestDeepEqualBasicTypes(t *testing.T) { - assert.Equal(t, true, ssz.DeepEqual(true, true)) - assert.Equal(t, false, ssz.DeepEqual(true, false)) + assert.Equal(t, true, equality.DeepEqual(true, true)) + assert.Equal(t, false, equality.DeepEqual(true, false)) - assert.Equal(t, true, ssz.DeepEqual(byte(222), byte(222))) - assert.Equal(t, false, ssz.DeepEqual(byte(222), byte(111))) + assert.Equal(t, true, equality.DeepEqual(byte(222), byte(222))) + assert.Equal(t, false, equality.DeepEqual(byte(222), byte(111))) - assert.Equal(t, true, ssz.DeepEqual(uint64(1234567890), uint64(1234567890))) - assert.Equal(t, false, ssz.DeepEqual(uint64(1234567890), uint64(987653210))) + assert.Equal(t, true, equality.DeepEqual(uint64(1234567890), uint64(1234567890))) + assert.Equal(t, false, equality.DeepEqual(uint64(1234567890), uint64(987653210))) - assert.Equal(t, true, ssz.DeepEqual("hello", "hello")) - assert.Equal(t, false, ssz.DeepEqual("hello", "world")) + assert.Equal(t, true, equality.DeepEqual("hello", "hello")) + assert.Equal(t, false, equality.DeepEqual("hello", "world")) - assert.Equal(t, true, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 3})) - assert.Equal(t, false, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 4})) + assert.Equal(t, true, equality.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 3})) + assert.Equal(t, false, equality.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 4})) var nilSlice1, nilSlice2 []byte - assert.Equal(t, true, ssz.DeepEqual(nilSlice1, nilSlice2)) - assert.Equal(t, true, ssz.DeepEqual(nilSlice1, []byte{})) - assert.Equal(t, true, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 3})) - assert.Equal(t, false, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 4})) + assert.Equal(t, true, equality.DeepEqual(nilSlice1, nilSlice2)) + assert.Equal(t, true, equality.DeepEqual(nilSlice1, []byte{})) + assert.Equal(t, true, equality.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 3})) + assert.Equal(t, false, equality.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 4})) } func TestDeepEqualStructs(t *testing.T) { @@ -39,8 +39,8 @@ func TestDeepEqualStructs(t *testing.T) { store1 := Store{uint64(1234), nil} store2 := Store{uint64(1234), []byte{}} store3 := Store{uint64(4321), []byte{}} - assert.Equal(t, true, ssz.DeepEqual(store1, store2)) - assert.Equal(t, false, ssz.DeepEqual(store1, store3)) + assert.Equal(t, true, equality.DeepEqual(store1, store2)) + assert.Equal(t, false, equality.DeepEqual(store1, store3)) } func TestDeepEqualStructs_Unexported(t *testing.T) { @@ -53,14 +53,14 @@ func TestDeepEqualStructs_Unexported(t *testing.T) { store2 := Store{uint64(1234), []byte{}, "hi there"} store3 := Store{uint64(4321), []byte{}, "wow"} store4 := Store{uint64(4321), []byte{}, "bow wow"} - assert.Equal(t, true, ssz.DeepEqual(store1, store2)) - assert.Equal(t, false, ssz.DeepEqual(store1, store3)) - assert.Equal(t, false, ssz.DeepEqual(store3, store4)) + assert.Equal(t, true, equality.DeepEqual(store1, store2)) + assert.Equal(t, false, equality.DeepEqual(store1, store3)) + assert.Equal(t, false, equality.DeepEqual(store3, store4)) } func TestDeepEqualProto(t *testing.T) { var fork1, fork2 *ethpb.Fork - assert.Equal(t, true, ssz.DeepEqual(fork1, fork2)) + assert.Equal(t, true, equality.DeepEqual(fork1, fork2)) fork1 = ðpb.Fork{ PreviousVersion: []byte{123}, @@ -72,8 +72,8 @@ func TestDeepEqualProto(t *testing.T) { CurrentVersion: []byte{125}, Epoch: 1234567890, } - assert.Equal(t, true, ssz.DeepEqual(fork1, fork1)) - assert.Equal(t, false, ssz.DeepEqual(fork1, fork2)) + assert.Equal(t, true, equality.DeepEqual(fork1, fork1)) + assert.Equal(t, false, equality.DeepEqual(fork1, fork2)) checkpoint1 := ðpb.Checkpoint{ Epoch: 1234567890, @@ -83,7 +83,7 @@ func TestDeepEqualProto(t *testing.T) { Epoch: 1234567890, Root: nil, } - assert.Equal(t, true, ssz.DeepEqual(checkpoint1, checkpoint2)) + assert.Equal(t, true, equality.DeepEqual(checkpoint1, checkpoint2)) } func Test_IsProto(t *testing.T) { @@ -125,7 +125,7 @@ func Test_IsProto(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := ssz.IsProto(tt.item); got != tt.want { + if got := equality.IsProto(tt.item); got != tt.want { t.Errorf("isProtoSlice() = %v, want %v", got, tt.want) } }) diff --git a/encoding/ssz/helpers.go b/encoding/ssz/helpers.go index 7532a7a89e..83eba70c5d 100644 --- a/encoding/ssz/helpers.go +++ b/encoding/ssz/helpers.go @@ -8,6 +8,7 @@ import ( "github.com/minio/sha256-simd" "github.com/pkg/errors" "github.com/prysmaticlabs/go-bitfield" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/encoding/bytesutil" ) @@ -50,6 +51,9 @@ func BitwiseMerkleize(hasher HashFn, chunks [][]byte, count, limit uint64) ([32] if count > limit { return [32]byte{}, errors.New("merkleizing list that is too large, over limit") } + if features.Get().EnableVectorizedHTR { + return MerkleizeList(chunks, limit), nil + } hashFn := NewHasherFunc(hasher) leafIndexer := func(i uint64) []byte { return chunks[i] @@ -62,6 +66,9 @@ func BitwiseMerkleizeArrays(hasher HashFn, chunks [][32]byte, count, limit uint6 if count > limit { return [32]byte{}, errors.New("merkleizing list that is too large, over limit") } + if features.Get().EnableVectorizedHTR { + return MerkleizeVector(chunks, limit), nil + } hashFn := NewHasherFunc(hasher) leafIndexer := func(i uint64) []byte { return chunks[i][:] diff --git a/encoding/ssz/merkleize.go b/encoding/ssz/merkleize.go index a9f542f1b7..2118004450 100644 --- a/encoding/ssz/merkleize.go +++ b/encoding/ssz/merkleize.go @@ -2,6 +2,7 @@ package ssz import ( "github.com/prysmaticlabs/prysm/container/trie" + "github.com/prysmaticlabs/prysm/crypto/hash/htr" ) // Merkleize.go is mostly a directly copy of the same filename from @@ -196,3 +197,40 @@ func ConstructProof(hasher Hasher, count, limit uint64, leaf func(i uint64) []by return } + +// MerkleizeVector uses our optimized routine to hash a list of 32-byte +// elements. +func MerkleizeVector(elements [][32]byte, length uint64) [32]byte { + depth := Depth(length) + // Return zerohash at depth + if len(elements) == 0 { + return trie.ZeroHashes[depth] + } + for i := 0; i < int(depth); i++ { + layerLen := len(elements) + oddNodeLength := layerLen%2 == 1 + if oddNodeLength { + zerohash := trie.ZeroHashes[i] + elements = append(elements, zerohash) + } + outputLen := len(elements) / 2 + htr.VectorizedSha256(elements, elements) + elements = elements[:outputLen] + } + return elements[0] +} + +// MerkleizeList uses our optimized routine to hash a 2d-list of +// elements. +func MerkleizeList(elements [][]byte, length uint64) [32]byte { + depth := Depth(length) + // Return zerohash at depth + if len(elements) == 0 { + return trie.ZeroHashes[depth] + } + newElems := make([][32]byte, len(elements)) + for i := range elements { + copy(newElems[i][:], elements[i]) + } + return MerkleizeVector(newElems, length) +} diff --git a/go.mod b/go.mod index d505c733f8..82d776d4f5 100644 --- a/go.mod +++ b/go.mod @@ -253,6 +253,7 @@ require ( github.com/holiman/uint256 v1.2.0 github.com/peterh/liner v1.2.0 // indirect github.com/prometheus/tsdb v0.10.0 // indirect + github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 // indirect google.golang.org/api v0.34.0 // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/go.sum b/go.sum index d247626763..ba7a0145da 100644 --- a/go.sum +++ b/go.sum @@ -1123,6 +1123,8 @@ github.com/prysmaticlabs/fastssz v0.0.0-20220110145812-fafb696cae88/go.mod h1:AS github.com/prysmaticlabs/go-bitfield v0.0.0-20210108222456-8e92c3709aa0/go.mod h1:hCwmef+4qXWjv0jLDbQdWnL0Ol7cS7/lCSS26WR+u6s= github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7 h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw= github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7/go.mod h1:wmuf/mdK4VMD+jA9ThwcUKjg3a2XWM9cVfFYjDyY4j4= +github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f h1:qVXMFNUGWbPpP7986YPsmL84A54oY7b/N+SOIBS3i5w= +github.com/prysmaticlabs/gohashtree v0.0.0-20220208111633-0606f58df32f/go.mod h1:4pWaT30XoEx1j8KNJf3TV+E3mQkaufn7mf+jRNb/Fuk= github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1 h1:xcu59yYL6AWWTl6jtejBfE0y8uF35fArCBeZjRlvJss= github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210702154020-550e1cd83ec1/go.mod h1:IOyTYjcIO0rkmnGBfJTL0NJ11exy/Tc2QEuv7hCXp24= github.com/prysmaticlabs/prombbolt v0.0.0-20210126082820-9b7adba6db7c h1:9PHRCuO/VN0s9k+RmLykho7AjDxblNYI5bYKed16NPU= diff --git a/proto/prysm/v1alpha1/attestation/aggregation/attestations/BUILD.bazel b/proto/prysm/v1alpha1/attestation/aggregation/attestations/BUILD.bazel index 2002ef2924..c581cf30f4 100644 --- a/proto/prysm/v1alpha1/attestation/aggregation/attestations/BUILD.bazel +++ b/proto/prysm/v1alpha1/attestation/aggregation/attestations/BUILD.bazel @@ -28,7 +28,7 @@ go_test( deps = [ "//config/params:go_default_library", "//crypto/bls:go_default_library", - "//encoding/ssz:go_default_library", + "//encoding/ssz/equality:go_default_library", "//proto/prysm/v1alpha1:go_default_library", "//proto/prysm/v1alpha1/attestation/aggregation:go_default_library", "//proto/prysm/v1alpha1/attestation/aggregation/testing:go_default_library", diff --git a/proto/prysm/v1alpha1/attestation/aggregation/attestations/attestations_test.go b/proto/prysm/v1alpha1/attestation/aggregation/attestations/attestations_test.go index 48256391d7..cdbaf71ba5 100644 --- a/proto/prysm/v1alpha1/attestation/aggregation/attestations/attestations_test.go +++ b/proto/prysm/v1alpha1/attestation/aggregation/attestations/attestations_test.go @@ -8,7 +8,7 @@ import ( "github.com/prysmaticlabs/go-bitfield" "github.com/prysmaticlabs/prysm/config/params" - "github.com/prysmaticlabs/prysm/encoding/ssz" + "github.com/prysmaticlabs/prysm/encoding/ssz/equality" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation" aggtesting "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/attestation/aggregation/testing" @@ -48,7 +48,7 @@ func TestAggregateAttestations_AggregatePair(t *testing.T) { for _, tt := range tests { got, err := AggregatePair(tt.a1, tt.a2) require.NoError(t, err) - require.Equal(t, true, ssz.DeepEqual(got, tt.want)) + require.Equal(t, true, equality.DeepEqual(got, tt.want)) } } diff --git a/testing/assertions/BUILD.bazel b/testing/assertions/BUILD.bazel index c64a7f9920..f17c264cd4 100644 --- a/testing/assertions/BUILD.bazel +++ b/testing/assertions/BUILD.bazel @@ -6,7 +6,7 @@ go_library( importpath = "github.com/prysmaticlabs/prysm/testing/assertions", visibility = ["//visibility:public"], deps = [ - "//encoding/ssz:go_default_library", + "//encoding/ssz/equality:go_default_library", "@com_github_d4l3k_messagediff//:go_default_library", "@com_github_sirupsen_logrus//hooks/test:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", diff --git a/testing/assertions/assertions.go b/testing/assertions/assertions.go index e2ec9246f3..86383ac80e 100644 --- a/testing/assertions/assertions.go +++ b/testing/assertions/assertions.go @@ -9,7 +9,7 @@ import ( "strings" "github.com/d4l3k/messagediff" - "github.com/prysmaticlabs/prysm/encoding/ssz" + "github.com/prysmaticlabs/prysm/encoding/ssz/equality" "github.com/sirupsen/logrus/hooks/test" "google.golang.org/protobuf/proto" ) @@ -61,7 +61,7 @@ func DeepNotEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg // DeepSSZEqual compares values using ssz.DeepEqual. func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) { - if !ssz.DeepEqual(expected, actual) { + if !equality.DeepEqual(expected, actual) { errMsg := parseMsg("Values are not equal", msg...) _, file, line, _ := runtime.Caller(2) diff, _ := messagediff.PrettyDiff(expected, actual) @@ -71,7 +71,7 @@ func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg // DeepNotSSZEqual compares values using ssz.DeepEqual. func DeepNotSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) { - if ssz.DeepEqual(expected, actual) { + if equality.DeepEqual(expected, actual) { errMsg := parseMsg("Values are equal", msg...) _, file, line, _ := runtime.Caller(2) loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual) diff --git a/tools/pcli/BUILD.bazel b/tools/pcli/BUILD.bazel index 09a70d3561..4e197a6c25 100644 --- a/tools/pcli/BUILD.bazel +++ b/tools/pcli/BUILD.bazel @@ -12,7 +12,7 @@ go_library( deps = [ "//beacon-chain/core/transition:go_default_library", "//beacon-chain/state/v1:go_default_library", - "//encoding/ssz:go_default_library", + "//encoding/ssz/equality:go_default_library", "//proto/prysm/v1alpha1:go_default_library", "//proto/prysm/v1alpha1/wrapper:go_default_library", "//runtime/version:go_default_library", diff --git a/tools/pcli/main.go b/tools/pcli/main.go index e728bd4277..db6926a244 100644 --- a/tools/pcli/main.go +++ b/tools/pcli/main.go @@ -13,7 +13,7 @@ import ( "github.com/kr/pretty" "github.com/prysmaticlabs/prysm/beacon-chain/core/transition" v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1" - "github.com/prysmaticlabs/prysm/encoding/ssz" + "github.com/prysmaticlabs/prysm/encoding/ssz/equality" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper" "github.com/prysmaticlabs/prysm/runtime/version" @@ -191,7 +191,7 @@ func main() { if err := dataFetcher(expectedPostStatePath, expectedState); err != nil { log.Fatal(err) } - if !ssz.DeepEqual(expectedState, postState.InnerStateUnsafe()) { + if !equality.DeepEqual(expectedState, postState.InnerStateUnsafe()) { diff, _ := messagediff.PrettyDiff(expectedState, postState.InnerStateUnsafe()) log.Errorf("Derived state differs from provided post state: %s", diff) }