diff --git a/shared/trieutil/BUILD.bazel b/shared/trieutil/BUILD.bazel index 08b3e4751b..27eda46945 100644 --- a/shared/trieutil/BUILD.bazel +++ b/shared/trieutil/BUILD.bazel @@ -12,6 +12,7 @@ go_library( deps = [ "//shared/bytesutil:go_default_library", "//shared/hashutil:go_default_library", + "//shared/params:go_default_library", ], ) diff --git a/shared/trieutil/merkle_trie.go b/shared/trieutil/merkle_trie.go index 63843de740..6712102c3c 100644 --- a/shared/trieutil/merkle_trie.go +++ b/shared/trieutil/merkle_trie.go @@ -1,5 +1,10 @@ package trieutil +import ( + "github.com/prysmaticlabs/prysm/shared/hashutil" + "github.com/prysmaticlabs/prysm/shared/params" +) + // NextPowerOf2 returns the next power of 2 >= the input // // Spec pseudocode definition: @@ -35,3 +40,40 @@ func PrevPowerOf2(n int) int { } return 2 * PrevPowerOf2(n/2) } + +// MerkleTree returns all the nodes in a merkle tree from inputting merkle leaves. +// +// Spec pseudocode definition: +// def merkle_tree(leaves: Sequence[Hash]) -> Sequence[Hash]: +// padded_length = get_next_power_of_two(len(leaves)) +// o = [Hash()] * padded_length + list(leaves) + [Hash()] * (padded_length - len(leaves)) +// for i in range(padded_length - 1, 0, -1): +// o[i] = hash(o[i * 2] + o[i * 2 + 1]) +// return o +func MerkleTree(leaves [][]byte) [][]byte { + paddedLength := NextPowerOf2(len(leaves)) + parents := make([][]byte, paddedLength) + paddedLeaves := make([][]byte, paddedLength-len(leaves)) + + for i := 0; i < len(parents); i++ { + parents[i] = params.BeaconConfig().ZeroHash[:] + } + for i := 0; i < len(paddedLeaves); i++ { + paddedLeaves[i] = params.BeaconConfig().ZeroHash[:] + } + + merkleTree := make([][]byte, len(parents)+len(leaves)+len(paddedLeaves)) + copy(merkleTree, parents) + l := len(parents) + copy(merkleTree[l:], leaves) + l += len(paddedLeaves) + copy(merkleTree[l:], paddedLeaves) + + for i := len(paddedLeaves) - 1; i > 0; i-- { + a := append(merkleTree[2*i], merkleTree[2*i+1]...) + b := hashutil.Hash(a) + merkleTree[i] = b[:] + } + + return merkleTree +} diff --git a/shared/trieutil/merkle_trie_test.go b/shared/trieutil/merkle_trie_test.go index c7ff663b63..da4bcf8438 100644 --- a/shared/trieutil/merkle_trie_test.go +++ b/shared/trieutil/merkle_trie_test.go @@ -1,6 +1,7 @@ package trieutil import ( + "math/rand" "testing" ) @@ -43,3 +44,33 @@ func TestPrevPowerOf2(t *testing.T) { } } } + +func TestMerkleTreeLength(t *testing.T) { + tests := []struct { + leaves [][]byte + length int + }{ + {[][]byte{{'A'}, {'B'}, {'C'}}, 8}, + {[][]byte{{'A'}, {'B'}, {'C'}, {'D'}}, 8}, + {[][]byte{{'A'}, {'B'}, {'C'}, {'D'}, {'E'}}, 16}, + } + for _, tt := range tests { + if got := MerkleTree(tt.leaves); len(got) != tt.length { + t.Errorf("len(MerkleTree()) = %v, want %v", got, tt.length) + } + } +} + +func BenchmarkMerkleTree_Generate(b *testing.B) { + leaves := make([][]byte, 1<<20) + for i := 0; i < len(leaves); i++ { + b := make([]byte, 32) + rand.Read(b) + leaves[i] = b + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + MerkleTree(leaves) + } +}