mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-10 07:58:22 -05:00
Merkle tree implementation (#3572)
This commit is contained in:
@@ -12,6 +12,7 @@ go_library(
|
||||
deps = [
|
||||
"//shared/bytesutil:go_default_library",
|
||||
"//shared/hashutil:go_default_library",
|
||||
"//shared/params:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
package trieutil
|
||||
|
||||
import (
|
||||
"github.com/prysmaticlabs/prysm/shared/hashutil"
|
||||
"github.com/prysmaticlabs/prysm/shared/params"
|
||||
)
|
||||
|
||||
// NextPowerOf2 returns the next power of 2 >= the input
|
||||
//
|
||||
// Spec pseudocode definition:
|
||||
@@ -35,3 +40,40 @@ func PrevPowerOf2(n int) int {
|
||||
}
|
||||
return 2 * PrevPowerOf2(n/2)
|
||||
}
|
||||
|
||||
// MerkleTree returns all the nodes in a merkle tree from inputting merkle leaves.
|
||||
//
|
||||
// Spec pseudocode definition:
|
||||
// def merkle_tree(leaves: Sequence[Hash]) -> Sequence[Hash]:
|
||||
// padded_length = get_next_power_of_two(len(leaves))
|
||||
// o = [Hash()] * padded_length + list(leaves) + [Hash()] * (padded_length - len(leaves))
|
||||
// for i in range(padded_length - 1, 0, -1):
|
||||
// o[i] = hash(o[i * 2] + o[i * 2 + 1])
|
||||
// return o
|
||||
func MerkleTree(leaves [][]byte) [][]byte {
|
||||
paddedLength := NextPowerOf2(len(leaves))
|
||||
parents := make([][]byte, paddedLength)
|
||||
paddedLeaves := make([][]byte, paddedLength-len(leaves))
|
||||
|
||||
for i := 0; i < len(parents); i++ {
|
||||
parents[i] = params.BeaconConfig().ZeroHash[:]
|
||||
}
|
||||
for i := 0; i < len(paddedLeaves); i++ {
|
||||
paddedLeaves[i] = params.BeaconConfig().ZeroHash[:]
|
||||
}
|
||||
|
||||
merkleTree := make([][]byte, len(parents)+len(leaves)+len(paddedLeaves))
|
||||
copy(merkleTree, parents)
|
||||
l := len(parents)
|
||||
copy(merkleTree[l:], leaves)
|
||||
l += len(paddedLeaves)
|
||||
copy(merkleTree[l:], paddedLeaves)
|
||||
|
||||
for i := len(paddedLeaves) - 1; i > 0; i-- {
|
||||
a := append(merkleTree[2*i], merkleTree[2*i+1]...)
|
||||
b := hashutil.Hash(a)
|
||||
merkleTree[i] = b[:]
|
||||
}
|
||||
|
||||
return merkleTree
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package trieutil
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -43,3 +44,33 @@ func TestPrevPowerOf2(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMerkleTreeLength(t *testing.T) {
|
||||
tests := []struct {
|
||||
leaves [][]byte
|
||||
length int
|
||||
}{
|
||||
{[][]byte{{'A'}, {'B'}, {'C'}}, 8},
|
||||
{[][]byte{{'A'}, {'B'}, {'C'}, {'D'}}, 8},
|
||||
{[][]byte{{'A'}, {'B'}, {'C'}, {'D'}, {'E'}}, 16},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := MerkleTree(tt.leaves); len(got) != tt.length {
|
||||
t.Errorf("len(MerkleTree()) = %v, want %v", got, tt.length)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMerkleTree_Generate(b *testing.B) {
|
||||
leaves := make([][]byte, 1<<20)
|
||||
for i := 0; i < len(leaves); i++ {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
leaves[i] = b
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
MerkleTree(leaves)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user