Add Encoding SSZ Package (#9630)

* ssz package

* compile

* htrutils

* rem pkg doc

* fix cloners_test.go

* fix circular dep/build issues

Co-authored-by: prestonvanloon <preston@prysmaticlabs.com>
Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
This commit is contained in:
Raul Jordan
2021-09-21 10:02:48 -05:00
committed by GitHub
parent b943f7bce5
commit 45bfd82c88
53 changed files with 361 additions and 377 deletions

46
encoding/ssz/BUILD.bazel Normal file
View File

@@ -0,0 +1,46 @@
load("@prysm//tools/go:def.bzl", "go_library", "go_test")

# Library with SSZ encoding helpers: hashing primitives, merkleization,
# hash-tree-root utilities, and an SSZ-aware deep-equality check.
go_library(
    name = "go_default_library",
    srcs = [
        "deep_equal.go",
        "hashers.go",
        "helpers.go",
        "htrutils.go",
        "merkleize.go",
    ],
    importpath = "github.com/prysmaticlabs/prysm/encoding/ssz",
    visibility = ["//visibility:public"],
    deps = [
        "//container/trie:go_default_library",
        "//crypto/hash:go_default_library",
        "//proto/prysm/v1alpha1:go_default_library",
        "//shared/bytesutil:go_default_library",
        "//shared/params:go_default_library",
        "@com_github_minio_sha256_simd//:go_default_library",
        "@com_github_pkg_errors//:go_default_library",
        "@com_github_prysmaticlabs_eth2_types//:go_default_library",
        "@com_github_prysmaticlabs_go_bitfield//:go_default_library",
        "@org_golang_google_protobuf//proto:go_default_library",
    ],
)

# Unit tests for the ssz package; kept "small" so they run in the fast tier.
go_test(
    name = "go_default_test",
    size = "small",
    srcs = [
        "deep_equal_test.go",
        "hashers_test.go",
        "helpers_test.go",
        "htrutils_test.go",
        "merkleize_test.go",
    ],
    deps = [
        ":go_default_library",
        "//crypto/hash:go_default_library",
        "//proto/prysm/v1alpha1:go_default_library",
        "//shared/testutil/assert:go_default_library",
        "//shared/testutil/require:go_default_library",
        "@com_github_prysmaticlabs_go_bitfield//:go_default_library",
    ],
)

323
encoding/ssz/deep_equal.go Normal file
View File

@@ -0,0 +1,323 @@
package ssz
import (
"reflect"
"unsafe"
types "github.com/prysmaticlabs/eth2-types"
"google.golang.org/protobuf/proto"
)
// During deepValueEqual, must keep track of checks that are
// in progress. The comparison algorithm assumes that all
// checks in progress are true when it reencounters them.
// Visited comparisons are stored in a map indexed by visit.
// visit records one in-progress pointer-pair comparison, keyed by both
// addresses and the shared type, so cyclic structures terminate.
type visit struct {
	a1  unsafe.Pointer /* #nosec G103 */ // lower of the two addresses (canonicalized by caller)
	a2  unsafe.Pointer /* #nosec G103 */ // higher of the two addresses
	typ reflect.Type   // static type of both values
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// This file extends Go's reflect.DeepEqual function into a ssz.DeepEqual
// function that is compliant with the supported types of ssz and its
// intricacies when determining equality of empty values.
//
// Tests for deep equality using reflected types. The map argument tracks
// comparisons that have already been seen, which allows short circuiting on
// recursive types.
// deepValueEqual reports whether v1 and v2 are deeply equal under SSZ
// semantics: nil and zero-length slices compare equal, and struct fields
// (exported and unexported alike) are all compared. The visited map
// short-circuits reference cycles; depth is threaded through the recursion
// but otherwise unused.
func deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
	if !v1.IsValid() || !v2.IsValid() {
		return v1.IsValid() == v2.IsValid()
	}
	if v1.Type() != v2.Type() {
		return false
	}
	// We want to avoid putting more in the visited map than we need to.
	// For any possible reference cycle that might be encountered,
	// hard(t) needs to return true for at least one of the types in the cycle.
	hard := func(k reflect.Kind) bool {
		switch k {
		case reflect.Slice, reflect.Ptr, reflect.Interface:
			return true
		}
		return false
	}
	if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
		addr1 := unsafe.Pointer(v1.UnsafeAddr()) /* #nosec G103 */
		addr2 := unsafe.Pointer(v2.UnsafeAddr()) /* #nosec G103 */
		if uintptr(addr1) > uintptr(addr2) {
			// Canonicalize order to reduce number of entries in visited.
			// Assumes non-moving garbage collector.
			addr1, addr2 = addr2, addr1
		}
		// Short circuit if references are already seen.
		typ := v1.Type()
		v := visit{addr1, addr2, typ}
		if visited[v] {
			return true
		}
		// Remember for later.
		visited[v] = true
	}
	switch v1.Kind() {
	case reflect.Array:
		for i := 0; i < v1.Len(); i++ {
			if !deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
				return false
			}
		}
		return true
	case reflect.Slice:
		// SSZ-specific relaxation: a nil slice and an empty slice compare
		// equal (unlike reflect.DeepEqual).
		if v1.IsNil() && v2.Len() == 0 {
			return true
		}
		if v1.Len() == 0 && v2.IsNil() {
			return true
		}
		if v1.IsNil() && v2.IsNil() {
			return true
		}
		if v1.Len() != v2.Len() {
			return false
		}
		// Same backing array and same length: contents cannot differ.
		if v1.Pointer() == v2.Pointer() {
			return true
		}
		for i := 0; i < v1.Len(); i++ {
			if !deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
				return false
			}
		}
		return true
	case reflect.Interface:
		if v1.IsNil() || v2.IsNil() {
			return v1.IsNil() == v2.IsNil()
		}
		return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
	case reflect.Ptr:
		// Identical pointers are trivially equal; otherwise compare pointees.
		if v1.Pointer() == v2.Pointer() {
			return true
		}
		return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
	case reflect.Struct:
		for i, n := 0, v1.NumField(); i < n; i++ {
			if !deepValueEqual(v1.Field(i), v2.Field(i), visited, depth+1) {
				return false
			}
		}
		return true
	default:
		// Remaining kinds are scalars (strings, integers, bools).
		return deepValueBaseTypeEqual(v1, v2)
	}
}
// deepValueEqualExportedOnly mirrors deepValueEqual but skips unexported
// struct fields. It is used for protobuf messages, whose unexported internal
// state (size caches, unknown fields) must not influence equality.
func deepValueEqualExportedOnly(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
	if !v1.IsValid() || !v2.IsValid() {
		return v1.IsValid() == v2.IsValid()
	}
	if v1.Type() != v2.Type() {
		return false
	}
	// We want to avoid putting more in the visited map than we need to.
	// For any possible reference cycle that might be encountered,
	// hard(t) needs to return true for at least one of the types in the cycle.
	hard := func(k reflect.Kind) bool {
		switch k {
		case reflect.Slice, reflect.Ptr, reflect.Interface:
			return true
		}
		return false
	}
	if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
		addr1 := unsafe.Pointer(v1.UnsafeAddr()) /* #nosec G103 */
		addr2 := unsafe.Pointer(v2.UnsafeAddr()) /* #nosec G103 */
		if uintptr(addr1) > uintptr(addr2) {
			// Canonicalize order to reduce number of entries in visited.
			// Assumes non-moving garbage collector.
			addr1, addr2 = addr2, addr1
		}
		// Short circuit if references are already seen.
		typ := v1.Type()
		v := visit{addr1, addr2, typ}
		if visited[v] {
			return true
		}
		// Remember for later.
		visited[v] = true
	}
	switch v1.Kind() {
	case reflect.Array:
		for i := 0; i < v1.Len(); i++ {
			if !deepValueEqualExportedOnly(v1.Index(i), v2.Index(i), visited, depth+1) {
				return false
			}
		}
		return true
	case reflect.Slice:
		// SSZ-specific relaxation: nil and empty slices compare equal.
		if v1.IsNil() && v2.Len() == 0 {
			return true
		}
		if v1.Len() == 0 && v2.IsNil() {
			return true
		}
		if v1.IsNil() && v2.IsNil() {
			return true
		}
		if v1.Len() != v2.Len() {
			return false
		}
		// Same backing array and same length: contents cannot differ.
		if v1.Pointer() == v2.Pointer() {
			return true
		}
		for i := 0; i < v1.Len(); i++ {
			if !deepValueEqualExportedOnly(v1.Index(i), v2.Index(i), visited, depth+1) {
				return false
			}
		}
		return true
	case reflect.Interface:
		if v1.IsNil() || v2.IsNil() {
			return v1.IsNil() == v2.IsNil()
		}
		return deepValueEqualExportedOnly(v1.Elem(), v2.Elem(), visited, depth+1)
	case reflect.Ptr:
		if v1.Pointer() == v2.Pointer() {
			return true
		}
		return deepValueEqualExportedOnly(v1.Elem(), v2.Elem(), visited, depth+1)
	case reflect.Struct:
		for i, n := 0, v1.NumField(); i < n; i++ {
			v1Field := v1.Field(i)
			v2Field := v2.Field(i)
			if !v1Field.CanInterface() || !v2Field.CanInterface() {
				// Continue for unexported fields, since they cannot be read anyways.
				continue
			}
			if !deepValueEqualExportedOnly(v1Field, v2Field, visited, depth+1) {
				return false
			}
		}
		return true
	default:
		return deepValueBaseTypeEqual(v1, v2)
	}
}
// deepValueBaseTypeEqual compares two reflect values of the same scalar kind
// (strings, unsigned integers, int32, bool). Any other kind is reported as
// unequal.
func deepValueBaseTypeEqual(v1, v2 reflect.Value) bool {
	// Values obtained from unexported struct fields cannot be read through
	// Interface() — it panics. Fall back to the kind-specific accessors,
	// which compare the same underlying data regardless of exportedness.
	if !v1.CanInterface() || !v2.CanInterface() {
		switch v1.Kind() {
		case reflect.String:
			return v1.String() == v2.String()
		case reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
			return v1.Uint() == v2.Uint()
		case reflect.Int32:
			return v1.Int() == v2.Int()
		case reflect.Bool:
			return v1.Bool() == v2.Bool()
		default:
			return false
		}
	}
	switch v1.Kind() {
	case reflect.String:
		return v1.String() == v2.String()
	case reflect.Uint64:
		// Distinguish the eth2 typed uint64 aliases so the assertions below
		// use the matching concrete type.
		switch v1.Type().Name() {
		case "Epoch":
			return v1.Interface().(types.Epoch) == v2.Interface().(types.Epoch)
		case "Slot":
			return v1.Interface().(types.Slot) == v2.Interface().(types.Slot)
		case "ValidatorIndex":
			return v1.Interface().(types.ValidatorIndex) == v2.Interface().(types.ValidatorIndex)
		case "CommitteeIndex":
			return v1.Interface().(types.CommitteeIndex) == v2.Interface().(types.CommitteeIndex)
		}
		return v1.Interface().(uint64) == v2.Interface().(uint64)
	case reflect.Uint32:
		return v1.Interface().(uint32) == v2.Interface().(uint32)
	case reflect.Int32:
		return v1.Interface().(int32) == v2.Interface().(int32)
	case reflect.Uint16:
		return v1.Interface().(uint16) == v2.Interface().(uint16)
	case reflect.Uint8:
		return v1.Interface().(uint8) == v2.Interface().(uint8)
	case reflect.Bool:
		return v1.Interface().(bool) == v2.Interface().(bool)
	default:
		return false
	}
}
// DeepEqual reports whether two SSZ-able values x and y are ``deeply equal,'' defined as follows:
// Two values of identical type are deeply equal if one of the following cases applies:
//
// Values of distinct types are never deeply equal.
//
// Array values are deeply equal when their corresponding elements are deeply equal.
//
// Struct values are deeply equal if their corresponding fields,
// both exported and unexported, are deeply equal.
//
// Interface values are deeply equal if they hold deeply equal concrete values.
//
// Pointer values are deeply equal if they are equal using Go's == operator
// or if they point to deeply equal values.
//
// Slice values are deeply equal when all of the following are true:
// they are both nil, one is nil and the other is empty or vice-versa,
// they have the same length, and either they point to the same initial entry of the same array
// (that is, &x[0] == &y[0]) or their corresponding elements (up to length) are deeply equal.
//
// Other values - numbers, bools, strings, and channels - are deeply equal
// if they are equal using Go's == operator.
//
// In general DeepEqual is a recursive relaxation of Go's == operator.
// However, this idea is impossible to implement without some inconsistency.
// Specifically, it is possible for a value to be unequal to itself,
// either because it is of func type (uncomparable in general)
// or because it is a floating-point NaN value (not equal to itself in floating-point comparison),
// or because it is an array, struct, or interface containing
// such a value.
//
// On the other hand, pointer values are always equal to themselves,
// even if they point at or contain such problematic values,
// because they compare equal using Go's == operator, and that
// is a sufficient condition to be deeply equal, regardless of content.
// DeepEqual has been defined so that the same short-cut applies
// to slices and maps: if x and y are the same slice or the same map,
// they are deeply equal regardless of content.
//
// As DeepEqual traverses the data values it may find a cycle. The
// second and subsequent times that DeepEqual compares two pointer
// values that have been compared before, it treats the values as
// equal rather than examining the values to which they point.
// This ensures that DeepEqual terminates.
//
// Credits go to the Go team as this is an extension of the official Go source code's
// reflect.DeepEqual function to handle special SSZ edge cases.
func DeepEqual(x, y interface{}) bool {
if x == nil || y == nil {
return x == y
}
v1 := reflect.ValueOf(x)
v2 := reflect.ValueOf(y)
if v1.Type() != v2.Type() {
return false
}
if IsProto(x) && IsProto(y) {
// Exclude unexported fields for protos.
return deepValueEqualExportedOnly(v1, v2, make(map[visit]bool), 0)
}
return deepValueEqual(v1, v2, make(map[visit]bool), 0)
}
// IsProto reports whether the item is a proto.Message, or — for slices,
// arrays, and maps — whether its element type implements proto.Message.
func IsProto(item interface{}) bool {
	typ := reflect.TypeOf(item)
	switch typ.Kind() {
	case reflect.Slice, reflect.Array, reflect.Map:
		// For containers, inspect the element type rather than the value.
		messageType := reflect.TypeOf((*proto.Message)(nil)).Elem()
		return typ.Elem().Implements(messageType)
	default:
		_, ok := item.(proto.Message)
		return ok
	}
}

View File

@@ -0,0 +1,133 @@
package ssz_test
import (
"testing"
"github.com/prysmaticlabs/prysm/encoding/ssz"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
)
// TestDeepEqualBasicTypes covers scalars, fixed arrays, and byte slices,
// including the SSZ rule that nil and empty slices compare equal.
func TestDeepEqualBasicTypes(t *testing.T) {
	assert.Equal(t, true, ssz.DeepEqual(true, true))
	assert.Equal(t, false, ssz.DeepEqual(true, false))
	assert.Equal(t, true, ssz.DeepEqual(byte(222), byte(222)))
	assert.Equal(t, false, ssz.DeepEqual(byte(222), byte(111)))
	assert.Equal(t, true, ssz.DeepEqual(uint64(1234567890), uint64(1234567890)))
	assert.Equal(t, false, ssz.DeepEqual(uint64(1234567890), uint64(987653210)))
	assert.Equal(t, true, ssz.DeepEqual("hello", "hello"))
	assert.Equal(t, false, ssz.DeepEqual("hello", "world"))
	assert.Equal(t, true, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 3}))
	assert.Equal(t, false, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 4}))
	var nilSlice1, nilSlice2 []byte
	assert.Equal(t, true, ssz.DeepEqual(nilSlice1, nilSlice2))
	assert.Equal(t, true, ssz.DeepEqual(nilSlice1, []byte{}))
	assert.Equal(t, true, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 3}))
	assert.Equal(t, false, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 4}))
}

// TestDeepEqualStructs verifies that struct comparison treats a nil byte
// slice field and an empty one as equal.
func TestDeepEqualStructs(t *testing.T) {
	type Store struct {
		V1 uint64
		V2 []byte
	}
	store1 := Store{uint64(1234), nil}
	store2 := Store{uint64(1234), []byte{}}
	store3 := Store{uint64(4321), []byte{}}
	assert.Equal(t, true, ssz.DeepEqual(store1, store2))
	assert.Equal(t, false, ssz.DeepEqual(store1, store3))
}

// TestDeepEqualStructs_Unexported verifies that unexported fields
// participate in equality for non-proto structs.
func TestDeepEqualStructs_Unexported(t *testing.T) {
	type Store struct {
		V1           uint64
		V2           []byte
		dontIgnoreMe string
	}
	store1 := Store{uint64(1234), nil, "hi there"}
	store2 := Store{uint64(1234), []byte{}, "hi there"}
	store3 := Store{uint64(4321), []byte{}, "wow"}
	store4 := Store{uint64(4321), []byte{}, "bow wow"}
	assert.Equal(t, true, ssz.DeepEqual(store1, store2))
	assert.Equal(t, false, ssz.DeepEqual(store1, store3))
	assert.Equal(t, false, ssz.DeepEqual(store3, store4))
}

// TestDeepEqualProto verifies proto messages compare by exported fields and
// that nil/empty byte slice fields are still treated as equal.
func TestDeepEqualProto(t *testing.T) {
	var fork1, fork2 *ethpb.Fork
	assert.Equal(t, true, ssz.DeepEqual(fork1, fork2))
	fork1 = &ethpb.Fork{
		PreviousVersion: []byte{123},
		CurrentVersion:  []byte{124},
		Epoch:           1234567890,
	}
	fork2 = &ethpb.Fork{
		PreviousVersion: []byte{123},
		CurrentVersion:  []byte{125},
		Epoch:           1234567890,
	}
	assert.Equal(t, true, ssz.DeepEqual(fork1, fork1))
	assert.Equal(t, false, ssz.DeepEqual(fork1, fork2))
	checkpoint1 := &ethpb.Checkpoint{
		Epoch: 1234567890,
		Root:  []byte{},
	}
	checkpoint2 := &ethpb.Checkpoint{
		Epoch: 1234567890,
		Root:  nil,
	}
	assert.Equal(t, true, ssz.DeepEqual(checkpoint1, checkpoint2))
}
// Test_IsProto exercises IsProto with scalars, slices, and maps, covering
// both proto and non-proto element types.
func Test_IsProto(t *testing.T) {
	tests := []struct {
		name string
		item interface{}
		want bool
	}{
		{
			// The untyped constant 0 is stored in the interface as an int.
			name: "uint64",
			item: 0,
			want: false,
		},
		{
			name: "string",
			item: "foobar cheese",
			want: false,
		},
		{
			name: "uint64 array",
			item: []uint64{1, 2, 3, 4, 5, 6},
			want: false,
		},
		{
			name: "Attestation",
			item: &ethpb.Attestation{},
			want: true,
		},
		{
			name: "Array of attestations",
			item: []*ethpb.Attestation{},
			want: true,
		},
		{
			name: "Map of attestations",
			item: make(map[uint64]*ethpb.Attestation),
			want: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := ssz.IsProto(tt.item); got != tt.want {
				// Fixed: the failure message previously referenced a
				// nonexistent "isProtoSlice()" helper.
				t.Errorf("IsProto() = %v, want %v", got, tt.want)
			}
		})
	}
}

51
encoding/ssz/hashers.go Normal file
View File

@@ -0,0 +1,51 @@
package ssz
import "encoding/binary"
// HashFn is the generic hash function signature.
type HashFn func(input []byte) [32]byte

// Hasher describes an interface through which we can
// perform hash operations on byte arrays, indices, etc.
type Hasher interface {
	Hash(a []byte) [32]byte
	Combi(a [32]byte, b [32]byte) [32]byte
	MixIn(a [32]byte, i uint64) [32]byte
}

// HasherFunc defines a structure to hold a hash function and can be used for
// multiple rounds of hashing. The 64-byte scratch buffer is reused across
// calls to avoid per-call allocations, so a HasherFunc is NOT safe for
// concurrent use.
type HasherFunc struct {
	b        [64]byte // scratch space for two-input (Combi/MixIn) hashing
	hashFunc HashFn
}

// NewHasherFunc is the constructor for the object
// that fulfills the Hasher interface.
func NewHasherFunc(h HashFn) *HasherFunc {
	return &HasherFunc{
		b:        [64]byte{},
		hashFunc: h,
	}
}

// Hash utilizes the provided hash function for the object.
func (h *HasherFunc) Hash(a []byte) [32]byte {
	return h.hashFunc(a)
}

// Combi appends the two inputs and hashes them.
func (h *HasherFunc) Combi(a, b [32]byte) [32]byte {
	copy(h.b[:32], a[:])
	copy(h.b[32:], b[:])
	return h.Hash(h.b[:])
}

// MixIn works like Combi, but using an integer as the second input.
func (h *HasherFunc) MixIn(a [32]byte, i uint64) [32]byte {
	copy(h.b[:32], a[:])
	// Serialize the integer little-endian into the first 8 bytes of the
	// second half and zero the remainder in place. The previous version
	// allocated a fresh 32-byte slice on every call just to clear the
	// buffer.
	binary.LittleEndian.PutUint64(h.b[32:40], i)
	for j := 40; j < 64; j++ {
		h.b[j] = 0
	}
	return h.Hash(h.b[:])
}

View File

@@ -0,0 +1,35 @@
package ssz_test
import (
"testing"
"github.com/prysmaticlabs/prysm/crypto/hash"
"github.com/prysmaticlabs/prysm/encoding/ssz"
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
)
// TestHash checks HasherFunc.Hash against a fixed vector produced by the
// project's SHA-256 hasher.
func TestHash(t *testing.T) {
	byteSlice := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}
	hasher := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
	expected := [32]byte{71, 228, 238, 127, 33, 31, 115, 38, 93, 209, 118, 88, 246, 226, 28, 19, 24, 189, 108, 129, 243, 117, 152, 226, 10, 39, 86, 41, 149, 66, 239, 207}
	result := hasher.Hash(byteSlice)
	assert.Equal(t, expected, result)
}

// TestCombi checks the two-root concatenation hash against a fixed vector.
func TestCombi(t *testing.T) {
	byteSlice1 := [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
	byteSlice2 := [32]byte{32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}
	hasher := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
	expected := [32]byte{203, 73, 0, 148, 142, 9, 145, 147, 186, 232, 143, 117, 95, 44, 38, 46, 102, 69, 101, 74, 50, 37, 87, 189, 40, 196, 203, 140, 19, 233, 161, 225}
	result := hasher.Combi(byteSlice1, byteSlice2)
	assert.Equal(t, expected, result)
}

// TestMixIn checks the root-plus-integer hash against a fixed vector.
func TestMixIn(t *testing.T) {
	byteSlice := [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
	intToAdd := uint64(33)
	hasher := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
	expected := [32]byte{170, 90, 0, 249, 34, 60, 140, 68, 77, 51, 218, 139, 54, 119, 179, 238, 80, 72, 13, 20, 212, 218, 124, 215, 68, 122, 214, 157, 178, 85, 225, 213}
	result := hasher.MixIn(byteSlice, intToAdd)
	assert.Equal(t, expected, result)
}

128
encoding/ssz/helpers.go Normal file
View File

@@ -0,0 +1,128 @@
// Package ssz defines HashTreeRoot utility functions.
package ssz
import (
"bytes"
"encoding/binary"
"github.com/minio/sha256-simd"
"github.com/pkg/errors"
"github.com/prysmaticlabs/go-bitfield"
)
const bytesPerChunk = 32
// BitlistRoot returns the mix in length of a bitwise Merkleized bitfield:
// the packed bitfield chunks are merkleized against the chunk limit implied
// by maxCapacity, then the bit length is mixed into the root.
func BitlistRoot(hasher HashFn, bfield bitfield.Bitfield, maxCapacity uint64) ([32]byte, error) {
	// Each 32-byte chunk holds 256 bits, so the chunk limit is the bit
	// capacity rounded up to a multiple of 256.
	limit := (maxCapacity + 255) / 256
	if bfield == nil || bfield.Len() == 0 {
		// An unset or empty bitlist merkleizes to the zero-chunk root with
		// a zero length mixed in.
		length := make([]byte, 32)
		root, err := BitwiseMerkleize(hasher, [][]byte{}, 0, limit)
		if err != nil {
			return [32]byte{}, err
		}
		return MixInLength(root, length), nil
	}
	chunks, err := Pack([][]byte{bfield.Bytes()})
	if err != nil {
		return [32]byte{}, err
	}
	// Mix in the bit count as a little-endian uint64, left-aligned in a
	// 32-byte buffer. Writing directly avoids the previous bytes.Buffer and
	// binary.Write round trip (and its impossible error path).
	output := make([]byte, 32)
	binary.LittleEndian.PutUint64(output, bfield.Len())
	root, err := BitwiseMerkleize(hasher, chunks, uint64(len(chunks)), limit)
	if err != nil {
		return [32]byte{}, err
	}
	return MixInLength(root, output), nil
}
// BitwiseMerkleize - given ordered BYTES_PER_CHUNK-byte chunks, if necessary utilize
// zero chunks so that the number of chunks is a power of two, Merkleize the chunks,
// and return the root.
// Note that merkleize on a single chunk is simply that chunk, i.e. the identity
// when the number of chunks is one.
func BitwiseMerkleize(hasher HashFn, chunks [][]byte, count, limit uint64) ([32]byte, error) {
	if count > limit {
		return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
	}
	// Merkleize pulls each leaf on demand via this indexer.
	return Merkleize(NewHasherFunc(hasher), count, limit, func(i uint64) []byte {
		return chunks[i]
	}), nil
}
// BitwiseMerkleizeArrays is used when a set of 32-byte root chunks are provided.
func BitwiseMerkleizeArrays(hasher HashFn, chunks [][32]byte, count, limit uint64) ([32]byte, error) {
	if count > limit {
		return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
	}
	// Each fixed-size chunk is re-sliced on demand for the leaf indexer.
	return Merkleize(NewHasherFunc(hasher), count, limit, func(i uint64) []byte {
		return chunks[i][:]
	}), nil
}
// Pack flattens the serialized items into 32-byte chunks, right-padding the
// final chunk with zeroes if needed. Empty input (or all-empty items) packs
// to a single zero chunk. Items that are already exactly one chunk wide are
// returned as-is.
func Pack(serializedItems [][]byte) ([][]byte, error) {
	areAllEmpty := true
	for _, item := range serializedItems {
		if !bytes.Equal(item, []byte{}) {
			areAllEmpty = false
			break
		}
	}
	// If there are no items, we return an empty chunk.
	if len(serializedItems) == 0 || areAllEmpty {
		return [][]byte{make([]byte, bytesPerChunk)}, nil
	}
	// If each item has exactly BYTES_PER_CHUNK length, reuse the input list.
	if len(serializedItems[0]) == bytesPerChunk {
		return serializedItems, nil
	}
	// We flatten the list in order to pack its items into byte chunks correctly.
	var orderedItems []byte
	for _, item := range serializedItems {
		orderedItems = append(orderedItems, item...)
	}
	numItems := len(orderedItems)
	// Preallocate: ceil(numItems / bytesPerChunk) chunks will be produced.
	chunks := make([][]byte, 0, (numItems+bytesPerChunk-1)/bytesPerChunk)
	for i := 0; i < numItems; i += bytesPerChunk {
		j := i + bytesPerChunk
		// Clamp the upper bound of the final, possibly short, chunk.
		if j > numItems {
			j = numItems
		}
		chunks = append(chunks, orderedItems[i:j])
	}
	// Right-pad the last chunk with zero bytes if it is short. A single
	// copy into a fresh zeroed chunk replaces the previous byte-by-byte
	// append loop and avoids aliasing the flattened buffer's spare capacity.
	if last := chunks[len(chunks)-1]; len(last) < bytesPerChunk {
		padded := make([]byte, bytesPerChunk)
		copy(padded, last)
		chunks[len(chunks)-1] = padded
	}
	return chunks, nil
}
// MixInLength hashes the 32-byte root together with the provided length
// buffer (SHA-256 of root || length), the SSZ mix-in-length step.
func MixInLength(root [32]byte, length []byte) [32]byte {
	hasher := sha256.New()
	// The hash.Hash contract guarantees Write never returns an error, so
	// the return values are deliberately ignored.
	// See https://golang.org/pkg/hash/#Hash
	hasher.Write(root[:]) // #nosec G104
	hasher.Write(length)  // #nosec G104
	var out [32]byte
	hasher.Sum(out[:0])
	return out
}

View File

@@ -0,0 +1,103 @@
package ssz_test
import (
"testing"
"github.com/prysmaticlabs/go-bitfield"
"github.com/prysmaticlabs/prysm/crypto/hash"
"github.com/prysmaticlabs/prysm/encoding/ssz"
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
"github.com/prysmaticlabs/prysm/shared/testutil/require"
)
// merkleizingListLimitError matches the error returned when count > limit.
const merkleizingListLimitError = "merkleizing list that is too large, over limit"

// TestBitlistRoot checks the bitlist hash-tree-root against a fixed vector.
func TestBitlistRoot(t *testing.T) {
	hasher := hash.CustomSHA256Hasher()
	capacity := uint64(10)
	bfield := bitfield.NewBitlist(capacity)
	expected := [32]byte{176, 76, 194, 203, 142, 166, 117, 79, 148, 194, 231, 64, 60, 245, 142, 32, 201, 2, 58, 152, 53, 12, 132, 40, 41, 102, 224, 189, 103, 41, 211, 202}
	result, err := ssz.BitlistRoot(hasher, bfield, capacity)
	require.NoError(t, err)
	assert.Equal(t, expected, result)
}

// TestBitwiseMerkleize checks a two-chunk merkleization vector.
func TestBitwiseMerkleize(t *testing.T) {
	hasher := hash.CustomSHA256Hasher()
	chunks := [][]byte{
		{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
		{11, 12, 13, 14, 15, 16, 17, 18, 19, 20},
	}
	count := uint64(2)
	limit := uint64(2)
	expected := [32]byte{194, 32, 213, 52, 220, 127, 18, 240, 43, 151, 19, 79, 188, 175, 142, 177, 208, 46, 96, 20, 18, 231, 208, 29, 120, 102, 122, 17, 46, 31, 155, 30}
	result, err := ssz.BitwiseMerkleize(hasher, chunks, count, limit)
	require.NoError(t, err)
	assert.Equal(t, expected, result)
}

// TestBitwiseMerkleizeOverLimit asserts the count > limit error path.
func TestBitwiseMerkleizeOverLimit(t *testing.T) {
	hasher := hash.CustomSHA256Hasher()
	chunks := [][]byte{
		{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
		{11, 12, 13, 14, 15, 16, 17, 18, 19, 20},
	}
	count := uint64(2)
	limit := uint64(1)
	_, err := ssz.BitwiseMerkleize(hasher, chunks, count, limit)
	assert.ErrorContains(t, merkleizingListLimitError, err)
}

// TestBitwiseMerkleizeArrays checks the fixed-size-chunk variant.
func TestBitwiseMerkleizeArrays(t *testing.T) {
	hasher := hash.CustomSHA256Hasher()
	chunks := [][32]byte{
		{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
		{33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 62, 62, 63, 64},
	}
	count := uint64(2)
	limit := uint64(2)
	expected := [32]byte{138, 81, 210, 194, 151, 231, 249, 241, 64, 118, 209, 58, 145, 109, 225, 89, 118, 110, 159, 220, 193, 183, 203, 124, 166, 24, 65, 26, 160, 215, 233, 219}
	result, err := ssz.BitwiseMerkleizeArrays(hasher, chunks, count, limit)
	require.NoError(t, err)
	assert.Equal(t, expected, result)
}

// TestBitwiseMerkleizeArraysOverLimit asserts the count > limit error path
// for the fixed-size-chunk variant.
func TestBitwiseMerkleizeArraysOverLimit(t *testing.T) {
	hasher := hash.CustomSHA256Hasher()
	chunks := [][32]byte{
		{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
		{33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 62, 62, 63, 64},
	}
	count := uint64(2)
	limit := uint64(1)
	_, err := ssz.BitwiseMerkleizeArrays(hasher, chunks, count, limit)
	assert.ErrorContains(t, merkleizingListLimitError, err)
}

// TestPack verifies two 9-byte items flatten into one zero-padded 32-byte chunk.
func TestPack(t *testing.T) {
	byteSlice2D := [][]byte{
		{1, 2, 3, 4, 5, 6, 7, 8, 9},
		{1, 1, 2, 3, 5, 8, 13, 21, 34},
	}
	expected := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2, 3, 5, 8, 13, 21, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	result, err := ssz.Pack(byteSlice2D)
	require.NoError(t, err)
	assert.Equal(t, len(expected), len(result[0]))
	for i, v := range expected {
		assert.DeepEqual(t, v, result[0][i])
	}
}

// TestMixInLength checks the length mix-in hash against a fixed vector.
func TestMixInLength(t *testing.T) {
	byteSlice := [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
	length := []byte{1, 2, 3}
	expected := [32]byte{105, 60, 167, 169, 197, 220, 122, 99, 59, 14, 250, 12, 251, 62, 135, 239, 29, 68, 140, 1, 6, 36, 207, 44, 64, 221, 76, 230, 237, 218, 150, 88}
	result := ssz.MixInLength(byteSlice, length)
	assert.Equal(t, expected, result)
}

92
encoding/ssz/htrutils.go Normal file
View File

@@ -0,0 +1,92 @@
package ssz
import (
"bytes"
"encoding/binary"
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/crypto/hash"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/shared/bytesutil"
"github.com/prysmaticlabs/prysm/shared/params"
)
// Uint64Root computes the HashTreeRoot Merkleization of
// a simple uint64 value according to the Ethereum
// Simple Serialize specification: the value serialized
// little-endian into the first eight bytes of an otherwise
// zeroed 32-byte chunk.
func Uint64Root(val uint64) [32]byte {
	var root [32]byte
	binary.LittleEndian.PutUint64(root[:], val)
	return root
}
// ForkRoot computes the HashTreeRoot Merkleization of
// a Fork struct value according to the Ethereum
// Simple Serialize specification.
//
// A nil fork leaves all three field roots nil, which merkleize as zero
// chunks. Count equals limit here, so no length mix-in is performed
// (Fork is a fixed-size container).
func ForkRoot(fork *ethpb.Fork) ([32]byte, error) {
	fieldRoots := make([][]byte, 3)
	if fork != nil {
		prevRoot := bytesutil.ToBytes32(fork.PreviousVersion)
		fieldRoots[0] = prevRoot[:]
		currRoot := bytesutil.ToBytes32(fork.CurrentVersion)
		fieldRoots[1] = currRoot[:]
		// The epoch is serialized little-endian, then zero-extended to a chunk.
		forkEpochBuf := make([]byte, 8)
		binary.LittleEndian.PutUint64(forkEpochBuf, uint64(fork.Epoch))
		epochRoot := bytesutil.ToBytes32(forkEpochBuf)
		fieldRoots[2] = epochRoot[:]
	}
	return BitwiseMerkleize(hash.CustomSHA256Hasher(), fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots)))
}
// CheckpointRoot computes the HashTreeRoot Merkleization of
// a Checkpoint struct value according to the Ethereum
// Simple Serialize specification.
//
// A nil checkpoint leaves both field roots nil, which merkleize as zero
// chunks. Count equals limit, so no length mix-in is performed.
func CheckpointRoot(hasher HashFn, checkpoint *ethpb.Checkpoint) ([32]byte, error) {
	fieldRoots := make([][]byte, 2)
	if checkpoint != nil {
		// The epoch is serialized little-endian, then zero-extended to a chunk.
		epochBuf := make([]byte, 8)
		binary.LittleEndian.PutUint64(epochBuf, uint64(checkpoint.Epoch))
		epochRoot := bytesutil.ToBytes32(epochBuf)
		fieldRoots[0] = epochRoot[:]
		ckpRoot := bytesutil.ToBytes32(checkpoint.Root)
		fieldRoots[1] = ckpRoot[:]
	}
	return BitwiseMerkleize(hasher, fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots)))
}
// ByteArrayRootWithLimit computes the HashTreeRoot Merkleization of
// a list of byte-slice roots according to the Ethereum Simple Serialize
// specification, mixing the list length into the final root.
func ByteArrayRootWithLimit(roots [][]byte, limit uint64) ([32]byte, error) {
	result, err := BitwiseMerkleize(hash.CustomSHA256Hasher(), roots, uint64(len(roots)), limit)
	if err != nil {
		return [32]byte{}, errors.Wrap(err, "could not compute byte array merkleization")
	}
	buf := new(bytes.Buffer)
	if err := binary.Write(buf, binary.LittleEndian, uint64(len(roots))); err != nil {
		return [32]byte{}, errors.Wrap(err, "could not marshal byte array length")
	}
	// We need to mix in the length of the slice.
	// The little-endian uint64 is left-aligned in a 32-byte buffer.
	output := make([]byte, 32)
	copy(output, buf.Bytes())
	mixedLen := MixInLength(result, output)
	return mixedLen, nil
}
// SlashingsRoot computes the HashTreeRoot Merkleization of
// a list of uint64 slashing values according to the Ethereum
// Simple Serialize specification.
//
// The list is laid out into a fixed-size vector of
// EpochsPerSlashingsVector entries; extra input values beyond that
// length are ignored, and missing entries remain zero.
func SlashingsRoot(slashings []uint64) ([32]byte, error) {
	slashingMarshaling := make([][]byte, params.BeaconConfig().EpochsPerSlashingsVector)
	for i := 0; i < len(slashings) && i < len(slashingMarshaling); i++ {
		slashBuf := make([]byte, 8)
		binary.LittleEndian.PutUint64(slashBuf, slashings[i])
		slashingMarshaling[i] = slashBuf
	}
	slashingChunks, err := Pack(slashingMarshaling)
	if err != nil {
		return [32]byte{}, errors.Wrap(err, "could not pack slashings into chunks")
	}
	return BitwiseMerkleize(hash.CustomSHA256Hasher(), slashingChunks, uint64(len(slashingChunks)), uint64(len(slashingChunks)))
}

View File

@@ -0,0 +1,63 @@
package ssz_test
import (
"testing"
"github.com/prysmaticlabs/prysm/crypto/hash"
"github.com/prysmaticlabs/prysm/encoding/ssz"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
"github.com/prysmaticlabs/prysm/shared/testutil/require"
)
// TestUint64Root checks the little-endian chunk encoding of a uint64.
func TestUint64Root(t *testing.T) {
	uintVal := uint64(1234567890)
	// NOTE(review): only 31 values are listed in this [32]byte literal; the
	// final byte is implicitly zero, which is the intended value.
	expected := [32]byte{210, 2, 150, 73, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	result := ssz.Uint64Root(uintVal)
	assert.Equal(t, expected, result)
}

// TestForkRoot checks the Fork container root against a fixed vector.
func TestForkRoot(t *testing.T) {
	testFork := ethpb.Fork{
		PreviousVersion: []byte{123},
		CurrentVersion:  []byte{124},
		Epoch:           1234567890,
	}
	expected := [32]byte{19, 46, 77, 103, 92, 175, 247, 33, 100, 64, 17, 111, 199, 145, 69, 38, 217, 112, 6, 16, 149, 201, 225, 144, 192, 228, 197, 172, 157, 78, 114, 140}
	result, err := ssz.ForkRoot(&testFork)
	require.NoError(t, err)
	assert.Equal(t, expected, result)
}

// TestCheckPointRoot checks the Checkpoint container root against a fixed vector.
func TestCheckPointRoot(t *testing.T) {
	testHasher := hash.CustomSHA256Hasher()
	testCheckpoint := ethpb.Checkpoint{
		Epoch: 1234567890,
		Root:  []byte{222},
	}
	expected := [32]byte{228, 65, 39, 109, 183, 249, 167, 232, 125, 239, 25, 155, 207, 4, 84, 174, 176, 229, 175, 224, 62, 33, 215, 254, 170, 220, 132, 65, 246, 128, 68, 194}
	result, err := ssz.CheckpointRoot(testHasher, &testCheckpoint)
	require.NoError(t, err)
	assert.Equal(t, expected, result)
}

// TestByteArrayRootWithLimit checks the length-mixed list root for the
// historical-roots limit (2^24).
func TestByteArrayRootWithLimit(t *testing.T) {
	testHistoricalRoots := [][]byte{{123}, {234}}
	expected := [32]byte{70, 204, 150, 196, 89, 138, 190, 205, 65, 207, 120, 166, 179, 247, 147, 20, 29, 133, 117, 116, 151, 234, 129, 32, 22, 15, 79, 178, 98, 73, 132, 152}
	result, err := ssz.ByteArrayRootWithLimit(testHistoricalRoots, 16777216)
	require.NoError(t, err)
	assert.Equal(t, expected, result)
}

// TestSlashingsRoot checks the slashings vector root against a fixed vector.
func TestSlashingsRoot(t *testing.T) {
	testSlashingsRoot := []uint64{123, 234}
	expected := [32]byte{123, 0, 0, 0, 0, 0, 0, 0, 234, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
	result, err := ssz.SlashingsRoot(testSlashingsRoot)
	require.NoError(t, err)
	assert.Equal(t, expected, result)
}

198
encoding/ssz/merkleize.go Normal file
View File

@@ -0,0 +1,198 @@
package ssz
import (
"github.com/prysmaticlabs/prysm/container/trie"
)
// Merkleize.go is mostly a directly copy of the same filename from
// https://github.com/protolambda/zssz/blob/master/merkle/merkleize.go.
// The reason the method is copied instead of imported is due to us using a
// a custom hasher interface for a reduced memory footprint when using
// 'Merkleize'.
// maskK selects every bit at position >= 2^K (mask0 = bits >= 1,
// mask1 = bits >= 2, ... mask5 = bits >= 32); used to binary-search
// the bit length of a uint64.
const (
	mask0 = ^uint64((1 << (1 << iota)) - 1)
	mask1
	mask2
	mask3
	mask4
	mask5
)

// bitK is the shift amount 2^K paired with maskK above.
const (
	bit0 = uint8(1 << iota)
	bit1
	bit2
	bit3
	bit4
	bit5
)
// Depth retrieves the appropriate depth for the provided trie size:
// the bit length of v-1, i.e. ceil(log2(v)), with the conventions
// Depth(0) == 0 and Depth(1) == 1. Offsetting by one keeps exact
// powers of two from being rounded up.
// Example:
// (in out): (0 0), (1 1), (2 1), (3 2), (4 2), (5 3), (6 3), (7 3), (8 3), (9 4)
func Depth(v uint64) (out uint8) {
	// Zero is a special case with zero depth.
	if v == 0 {
		return 0
	}
	v--
	// Binary search for the highest set bit with shrinking shift widths
	// (32, 16, 8, 4, 2, 1), accumulating the bit index into out.
	for shift := uint8(32); shift >= 1; shift >>= 1 {
		if v >= uint64(1)<<shift {
			v >>= shift
			out |= shift
		}
	}
	// Convert the index of the first bit into the length of the bits
	// (the depth of the tree).
	out++
	return
}
// Merkleize hashes the count leaves produced by the leaf callback into the
// root of a binary Merkle tree that is virtually padded with zero-hashes up
// to the given limit, using log(N) space allocation rather than
// materializing the whole tree.
//
// Panics when count exceeds limit. A limit of 0 yields the zero root; a
// limit of 1 yields the single leaf itself (or the zero root when count is 0).
func Merkleize(hasher Hasher, count, limit uint64, leaf func(i uint64) []byte) (out [32]byte) {
	if count > limit {
		panic("merkleizing list that is too large, over limit")
	}
	if limit == 0 {
		return
	}
	if limit == 1 {
		if count == 1 {
			copy(out[:], leaf(0))
		}
		return
	}
	depth := Depth(count)
	limitDepth := Depth(limit)
	// tmp[j] holds the pending (not-yet-merged) subtree root at height j;
	// one slot per level is the only state required.
	tmp := make([][32]byte, limitDepth+1)

	j := uint8(0)
	// hArr/h is the scratch hash being carried upward; h aliases hArr so the
	// merge loop can overwrite it in place via copy.
	hArr := [32]byte{}
	h := hArr[:]

	// merge folds the node for position i (currently in h) into tmp,
	// combining with stored siblings as far up the tree as possible.
	merge := func(i uint64) {
		// merge back up from bottom to top, as far as we can
		for j = 0; ; j++ {
			// stop merging when we are in the left side of the next combi
			if i&(uint64(1)<<j) == 0 {
				// if we are at the count, we want to merge in zero-hashes for padding
				if i == count && j < depth {
					v := hasher.Combi(hArr, trie.ZeroHashes[j])
					copy(h, v[:])
				} else {
					break
				}
			} else {
				// keep merging up if we are the right side
				v := hasher.Combi(tmp[j], hArr)
				copy(h, v[:])
			}
		}
		// store the merge result (may be no merge, i.e. bottom leaf node)
		copy(tmp[j][:], h)
	}

	// merge in leaf by leaf.
	for i := uint64(0); i < count; i++ {
		copy(h, leaf(i))
		merge(i)
	}

	// complement with 0 if empty, or if not the right power of 2
	if (uint64(1) << depth) != count {
		copy(h, trie.ZeroHashes[0][:])
		merge(count)
	}

	// the next power of two may be smaller than the ultimate virtual size,
	// complement with zero-hashes at each depth.
	for j := depth; j < limitDepth; j++ {
		tmp[j+1] = hasher.Combi(tmp[j], trie.ZeroHashes[j])
	}

	return tmp[limitDepth]
}
// ConstructProof builds a merkle-branch of the given depth, at the given index (at that depth),
// for a list of leafs of a balanced binary tree.
//
// The returned branch holds one sibling hash per level, bottom-up
// (Depth(limit) entries), pre-filled with zero-hashes so that padding levels
// are already correct before real siblings overwrite them. Panics when count
// exceeds limit or when index is not below limit.
func ConstructProof(hasher Hasher, count, limit uint64, leaf func(i uint64) []byte, index uint64) (branch [][32]byte) {
	if count > limit {
		panic("merkleizing list that is too large, over limit")
	}
	if index >= limit {
		panic("index out of range, over limit")
	}
	if limit <= 1 {
		return
	}
	depth := Depth(count)
	limitDepth := Depth(limit)
	// Seed every level of the branch with the zero-hash for that level.
	branch = append(branch, trie.ZeroHashes[:limitDepth]...)

	// tmp[j] holds the pending (not-yet-merged) subtree root at height j.
	tmp := make([][32]byte, limitDepth+1)

	j := uint8(0)
	// hArr/h is the scratch hash carried upward; h aliases hArr so the merge
	// loop can overwrite it in place via copy.
	hArr := [32]byte{}
	h := hArr[:]

	// merge folds the node for position i (currently in h) into tmp, and
	// captures any completed subtree that is a sibling of index into branch.
	merge := func(i uint64) {
		// merge back up from bottom to top, as far as we can
		for j = 0; ; j++ {
			// if i is a sibling of index at the given depth,
			// and i is the last index of the subtree to that depth,
			// then put h into the branch
			if (i>>j)^1 == (index>>j) && (((1<<j)-1)&i) == ((1<<j)-1) {
				// insert sibling into the proof
				branch[j] = hArr
			}
			// stop merging when we are in the left side of the next combi
			if i&(uint64(1)<<j) == 0 {
				// if we are at the count, we want to merge in zero-hashes for padding
				if i == count && j < depth {
					v := hasher.Combi(hArr, trie.ZeroHashes[j])
					copy(h, v[:])
				} else {
					break
				}
			} else {
				// keep merging up if we are the right side
				v := hasher.Combi(tmp[j], hArr)
				copy(h, v[:])
			}
		}
		// store the merge result (may be no merge, i.e. bottom leaf node)
		copy(tmp[j][:], h)
	}

	// merge in leaf by leaf.
	for i := uint64(0); i < count; i++ {
		copy(h, leaf(i))
		merge(i)
	}

	// complement with 0 if empty, or if not the right power of 2
	if (uint64(1) << depth) != count {
		copy(h, trie.ZeroHashes[0][:])
		merge(count)
	}
	return
}

View File

@@ -0,0 +1,114 @@
package ssz_test
import (
"testing"
"github.com/prysmaticlabs/prysm/crypto/hash"
"github.com/prysmaticlabs/prysm/encoding/ssz"
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
)
// TestGetDepth verifies the tree depth computed for a large trie size.
func TestGetDepth(t *testing.T) {
	got := ssz.Depth(uint64(896745231))
	assert.Equal(t, uint8(30), got)
}
// TestMerkleizeCountGreaterThanLimit expects Merkleize to panic when the
// number of leaves exceeds the limit.
func TestMerkleizeCountGreaterThanLimit(t *testing.T) {
	// Error if no panic
	defer func() {
		if recover() == nil {
			t.Errorf("The code did not panic.")
		}
	}()
	hashFn := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
	chunks := [][]byte{{}}
	ssz.Merkleize(hashFn, 2, 1, func(i uint64) []byte { return chunks[i] })
}
// TestMerkleizeLimitAndCountAreZero expects the zero root when both the
// leaf count and the limit are zero.
func TestMerkleizeLimitAndCountAreZero(t *testing.T) {
	hashFn := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
	chunks := [][]byte{{}}
	got := ssz.Merkleize(hashFn, 0, 0, func(i uint64) []byte { return chunks[i] })
	assert.Equal(t, [32]byte{}, got)
}
// TestMerkleizeNormalPath checks the root of two leaves merkleized against
// a limit of three against a known-good value.
func TestMerkleizeNormalPath(t *testing.T) {
	hashFn := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
	chunks := [][]byte{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}
	want := [32]byte{95, 27, 253, 237, 215, 58, 147, 198, 175, 194, 180, 231, 154, 130, 205, 68, 146, 112, 225, 86, 6, 103, 186, 82, 7, 142, 33, 189, 174, 56, 199, 173}
	got := ssz.Merkleize(hashFn, 2, 3, func(i uint64) []byte { return chunks[i] })
	assert.Equal(t, want, got)
}
// TestConstructProofCountGreaterThanLimit expects ConstructProof to panic
// when the number of leaves exceeds the limit.
func TestConstructProofCountGreaterThanLimit(t *testing.T) {
	defer func() {
		if recover() == nil {
			t.Errorf("The code did not panic.")
		}
	}()
	hashFn := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
	chunks := [][]byte{{}}
	ssz.ConstructProof(hashFn, 2, 1, func(i uint64) []byte { return chunks[i] }, 0)
}
// TestConstructProofIndexGreaterThanEqualToLimit expects ConstructProof to
// panic when the requested proof index is not below the limit.
func TestConstructProofIndexGreaterThanEqualToLimit(t *testing.T) {
	defer func() {
		if recover() == nil {
			t.Errorf("The code did not panic.")
		}
	}()
	hashFn := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
	chunks := [][]byte{{}}
	ssz.ConstructProof(hashFn, 1, 1, func(i uint64) []byte { return chunks[i] }, 1)
}
// TestConstructProofNormalPath checks the merkle branch produced for leaf
// index 1 of a two-leaf tree (limit 3) against the known sibling hashes.
func TestConstructProofNormalPath(t *testing.T) {
	hashFn := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
	count := uint64(2)
	limit := uint64(3)
	chunks := [][]byte{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}
	leafIndexer := func(i uint64) []byte {
		return chunks[i]
	}
	index := uint64(1)
	expected := [][32]byte{
		{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
		{245, 165, 253, 66, 209, 106, 32, 48, 39, 152, 239, 110, 211, 9, 151, 155, 67, 0, 61, 35, 32, 217, 240, 232, 234, 152, 49, 169, 39, 89, 251, 75},
	}
	result := ssz.ConstructProof(hashFn, count, limit, leafIndexer, index)
	assert.Equal(t, len(expected), len(result))
	for i, v := range expected {
		// Expected value first, matching the assert.Equal(t, expected, actual)
		// convention used throughout this file.
		assert.DeepEqual(t, v, result[i])
	}
}