mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-08 23:18:15 -05:00
Add Encoding SSZ Package (#9630)
* ssz package * compile * htrutils * rem pkg doc * fix cloners_test.go * fix circular dep/build issues Co-authored-by: prestonvanloon <preston@prysmaticlabs.com> Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
This commit is contained in:
46
encoding/ssz/BUILD.bazel
Normal file
46
encoding/ssz/BUILD.bazel
Normal file
@@ -0,0 +1,46 @@
|
||||
load("@prysm//tools/go:def.bzl", "go_library", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = [
|
||||
"deep_equal.go",
|
||||
"hashers.go",
|
||||
"helpers.go",
|
||||
"htrutils.go",
|
||||
"merkleize.go",
|
||||
],
|
||||
importpath = "github.com/prysmaticlabs/prysm/encoding/ssz",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//container/trie:go_default_library",
|
||||
"//crypto/hash:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"//shared/bytesutil:go_default_library",
|
||||
"//shared/params:go_default_library",
|
||||
"@com_github_minio_sha256_simd//:go_default_library",
|
||||
"@com_github_pkg_errors//:go_default_library",
|
||||
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
|
||||
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
|
||||
"@org_golang_google_protobuf//proto:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "go_default_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"deep_equal_test.go",
|
||||
"hashers_test.go",
|
||||
"helpers_test.go",
|
||||
"htrutils_test.go",
|
||||
"merkleize_test.go",
|
||||
],
|
||||
deps = [
|
||||
":go_default_library",
|
||||
"//crypto/hash:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"//shared/testutil/assert:go_default_library",
|
||||
"//shared/testutil/require:go_default_library",
|
||||
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
|
||||
],
|
||||
)
|
||||
323
encoding/ssz/deep_equal.go
Normal file
323
encoding/ssz/deep_equal.go
Normal file
@@ -0,0 +1,323 @@
|
||||
package ssz
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
types "github.com/prysmaticlabs/eth2-types"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// During deepValueEqual, must keep track of checks that are
|
||||
// in progress. The comparison algorithm assumes that all
|
||||
// checks in progress are true when it reencounters them.
|
||||
// Visited comparisons are stored in a map indexed by visit.
|
||||
type visit struct {
|
||||
a1 unsafe.Pointer /* #nosec G103 */
|
||||
a2 unsafe.Pointer /* #nosec G103 */
|
||||
typ reflect.Type
|
||||
}
|
||||
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
//
|
||||
// This file extends Go's reflect.DeepEqual function into a ssz.DeepEqual
|
||||
// function that is compliant with the supported types of ssz and its
|
||||
// intricacies when determining equality of empty values.
|
||||
//
|
||||
// Tests for deep equality using reflected types. The map argument tracks
|
||||
// comparisons that have already been seen, which allows short circuiting on
|
||||
// recursive types.
|
||||
func deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
|
||||
if !v1.IsValid() || !v2.IsValid() {
|
||||
return v1.IsValid() == v2.IsValid()
|
||||
}
|
||||
if v1.Type() != v2.Type() {
|
||||
return false
|
||||
}
|
||||
// We want to avoid putting more in the visited map than we need to.
|
||||
// For any possible reference cycle that might be encountered,
|
||||
// hard(t) needs to return true for at least one of the types in the cycle.
|
||||
hard := func(k reflect.Kind) bool {
|
||||
switch k {
|
||||
case reflect.Slice, reflect.Ptr, reflect.Interface:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
|
||||
addr1 := unsafe.Pointer(v1.UnsafeAddr()) /* #nosec G103 */
|
||||
addr2 := unsafe.Pointer(v2.UnsafeAddr()) /* #nosec G103 */
|
||||
|
||||
if uintptr(addr1) > uintptr(addr2) {
|
||||
// Canonicalize order to reduce number of entries in visited.
|
||||
// Assumes non-moving garbage collector.
|
||||
addr1, addr2 = addr2, addr1
|
||||
}
|
||||
|
||||
// Short circuit if references are already seen.
|
||||
typ := v1.Type()
|
||||
v := visit{addr1, addr2, typ}
|
||||
if visited[v] {
|
||||
return true
|
||||
}
|
||||
|
||||
// Remember for later.
|
||||
visited[v] = true
|
||||
}
|
||||
|
||||
switch v1.Kind() {
|
||||
case reflect.Array:
|
||||
for i := 0; i < v1.Len(); i++ {
|
||||
if !deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case reflect.Slice:
|
||||
if v1.IsNil() && v2.Len() == 0 {
|
||||
return true
|
||||
}
|
||||
if v1.Len() == 0 && v2.IsNil() {
|
||||
return true
|
||||
}
|
||||
if v1.IsNil() && v2.IsNil() {
|
||||
return true
|
||||
}
|
||||
if v1.Len() != v2.Len() {
|
||||
return false
|
||||
}
|
||||
if v1.Pointer() == v2.Pointer() {
|
||||
return true
|
||||
}
|
||||
for i := 0; i < v1.Len(); i++ {
|
||||
if !deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case reflect.Interface:
|
||||
if v1.IsNil() || v2.IsNil() {
|
||||
return v1.IsNil() == v2.IsNil()
|
||||
}
|
||||
return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
|
||||
case reflect.Ptr:
|
||||
if v1.Pointer() == v2.Pointer() {
|
||||
return true
|
||||
}
|
||||
return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
|
||||
case reflect.Struct:
|
||||
for i, n := 0, v1.NumField(); i < n; i++ {
|
||||
if !deepValueEqual(v1.Field(i), v2.Field(i), visited, depth+1) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
default:
|
||||
return deepValueBaseTypeEqual(v1, v2)
|
||||
}
|
||||
}
|
||||
|
||||
func deepValueEqualExportedOnly(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
|
||||
if !v1.IsValid() || !v2.IsValid() {
|
||||
return v1.IsValid() == v2.IsValid()
|
||||
}
|
||||
if v1.Type() != v2.Type() {
|
||||
return false
|
||||
}
|
||||
// We want to avoid putting more in the visited map than we need to.
|
||||
// For any possible reference cycle that might be encountered,
|
||||
// hard(t) needs to return true for at least one of the types in the cycle.
|
||||
hard := func(k reflect.Kind) bool {
|
||||
switch k {
|
||||
case reflect.Slice, reflect.Ptr, reflect.Interface:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
|
||||
addr1 := unsafe.Pointer(v1.UnsafeAddr()) /* #nosec G103 */
|
||||
addr2 := unsafe.Pointer(v2.UnsafeAddr()) /* #nosec G103 */
|
||||
if uintptr(addr1) > uintptr(addr2) {
|
||||
// Canonicalize order to reduce number of entries in visited.
|
||||
// Assumes non-moving garbage collector.
|
||||
addr1, addr2 = addr2, addr1
|
||||
}
|
||||
|
||||
// Short circuit if references are already seen.
|
||||
typ := v1.Type()
|
||||
v := visit{addr1, addr2, typ}
|
||||
if visited[v] {
|
||||
return true
|
||||
}
|
||||
|
||||
// Remember for later.
|
||||
visited[v] = true
|
||||
}
|
||||
|
||||
switch v1.Kind() {
|
||||
case reflect.Array:
|
||||
for i := 0; i < v1.Len(); i++ {
|
||||
if !deepValueEqualExportedOnly(v1.Index(i), v2.Index(i), visited, depth+1) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case reflect.Slice:
|
||||
if v1.IsNil() && v2.Len() == 0 {
|
||||
return true
|
||||
}
|
||||
if v1.Len() == 0 && v2.IsNil() {
|
||||
return true
|
||||
}
|
||||
if v1.IsNil() && v2.IsNil() {
|
||||
return true
|
||||
}
|
||||
if v1.Len() != v2.Len() {
|
||||
return false
|
||||
}
|
||||
if v1.Pointer() == v2.Pointer() {
|
||||
return true
|
||||
}
|
||||
for i := 0; i < v1.Len(); i++ {
|
||||
if !deepValueEqualExportedOnly(v1.Index(i), v2.Index(i), visited, depth+1) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case reflect.Interface:
|
||||
if v1.IsNil() || v2.IsNil() {
|
||||
return v1.IsNil() == v2.IsNil()
|
||||
}
|
||||
return deepValueEqualExportedOnly(v1.Elem(), v2.Elem(), visited, depth+1)
|
||||
case reflect.Ptr:
|
||||
if v1.Pointer() == v2.Pointer() {
|
||||
return true
|
||||
}
|
||||
return deepValueEqualExportedOnly(v1.Elem(), v2.Elem(), visited, depth+1)
|
||||
case reflect.Struct:
|
||||
for i, n := 0, v1.NumField(); i < n; i++ {
|
||||
v1Field := v1.Field(i)
|
||||
v2Field := v2.Field(i)
|
||||
if !v1Field.CanInterface() || !v2Field.CanInterface() {
|
||||
// Continue for unexported fields, since they cannot be read anyways.
|
||||
continue
|
||||
}
|
||||
if !deepValueEqualExportedOnly(v1Field, v2Field, visited, depth+1) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
default:
|
||||
return deepValueBaseTypeEqual(v1, v2)
|
||||
}
|
||||
}
|
||||
|
||||
func deepValueBaseTypeEqual(v1, v2 reflect.Value) bool {
|
||||
switch v1.Kind() {
|
||||
case reflect.String:
|
||||
return v1.String() == v2.String()
|
||||
case reflect.Uint64:
|
||||
switch v1.Type().Name() {
|
||||
case "Epoch":
|
||||
return v1.Interface().(types.Epoch) == v2.Interface().(types.Epoch)
|
||||
case "Slot":
|
||||
return v1.Interface().(types.Slot) == v2.Interface().(types.Slot)
|
||||
case "ValidatorIndex":
|
||||
return v1.Interface().(types.ValidatorIndex) == v2.Interface().(types.ValidatorIndex)
|
||||
case "CommitteeIndex":
|
||||
return v1.Interface().(types.CommitteeIndex) == v2.Interface().(types.CommitteeIndex)
|
||||
}
|
||||
return v1.Interface().(uint64) == v2.Interface().(uint64)
|
||||
case reflect.Uint32:
|
||||
return v1.Interface().(uint32) == v2.Interface().(uint32)
|
||||
case reflect.Int32:
|
||||
return v1.Interface().(int32) == v2.Interface().(int32)
|
||||
case reflect.Uint16:
|
||||
return v1.Interface().(uint16) == v2.Interface().(uint16)
|
||||
case reflect.Uint8:
|
||||
return v1.Interface().(uint8) == v2.Interface().(uint8)
|
||||
case reflect.Bool:
|
||||
return v1.Interface().(bool) == v2.Interface().(bool)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// DeepEqual reports whether two SSZ-able values x and y are ``deeply equal,'' defined as follows:
|
||||
// Two values of identical type are deeply equal if one of the following cases applies:
|
||||
//
|
||||
// Values of distinct types are never deeply equal.
|
||||
//
|
||||
// Array values are deeply equal when their corresponding elements are deeply equal.
|
||||
//
|
||||
// Struct values are deeply equal if their corresponding fields,
|
||||
// both exported and unexported, are deeply equal.
|
||||
//
|
||||
// Interface values are deeply equal if they hold deeply equal concrete values.
|
||||
//
|
||||
// Pointer values are deeply equal if they are equal using Go's == operator
|
||||
// or if they point to deeply equal values.
|
||||
//
|
||||
// Slice values are deeply equal when all of the following are true:
|
||||
// they are both nil, one is nil and the other is empty or vice-versa,
|
||||
// they have the same length, and either they point to the same initial entry of the same array
|
||||
// (that is, &x[0] == &y[0]) or their corresponding elements (up to length) are deeply equal.
|
||||
//
|
||||
// Other values - numbers, bools, strings, and channels - are deeply equal
|
||||
// if they are equal using Go's == operator.
|
||||
//
|
||||
// In general DeepEqual is a recursive relaxation of Go's == operator.
|
||||
// However, this idea is impossible to implement without some inconsistency.
|
||||
// Specifically, it is possible for a value to be unequal to itself,
|
||||
// either because it is of func type (uncomparable in general)
|
||||
// or because it is a floating-point NaN value (not equal to itself in floating-point comparison),
|
||||
// or because it is an array, struct, or interface containing
|
||||
// such a value.
|
||||
//
|
||||
// On the other hand, pointer values are always equal to themselves,
|
||||
// even if they point at or contain such problematic values,
|
||||
// because they compare equal using Go's == operator, and that
|
||||
// is a sufficient condition to be deeply equal, regardless of content.
|
||||
// DeepEqual has been defined so that the same short-cut applies
|
||||
// to slices and maps: if x and y are the same slice or the same map,
|
||||
// they are deeply equal regardless of content.
|
||||
//
|
||||
// As DeepEqual traverses the data values it may find a cycle. The
|
||||
// second and subsequent times that DeepEqual compares two pointer
|
||||
// values that have been compared before, it treats the values as
|
||||
// equal rather than examining the values to which they point.
|
||||
// This ensures that DeepEqual terminates.
|
||||
//
|
||||
// Credits go to the Go team as this is an extension of the official Go source code's
|
||||
// reflect.DeepEqual function to handle special SSZ edge cases.
|
||||
func DeepEqual(x, y interface{}) bool {
|
||||
if x == nil || y == nil {
|
||||
return x == y
|
||||
}
|
||||
v1 := reflect.ValueOf(x)
|
||||
v2 := reflect.ValueOf(y)
|
||||
if v1.Type() != v2.Type() {
|
||||
return false
|
||||
}
|
||||
if IsProto(x) && IsProto(y) {
|
||||
// Exclude unexported fields for protos.
|
||||
return deepValueEqualExportedOnly(v1, v2, make(map[visit]bool), 0)
|
||||
}
|
||||
return deepValueEqual(v1, v2, make(map[visit]bool), 0)
|
||||
}
|
||||
|
||||
func IsProto(item interface{}) bool {
|
||||
typ := reflect.TypeOf(item)
|
||||
kind := typ.Kind()
|
||||
if kind != reflect.Slice && kind != reflect.Array && kind != reflect.Map {
|
||||
_, ok := item.(proto.Message)
|
||||
return ok
|
||||
}
|
||||
elemTyp := typ.Elem()
|
||||
modelType := reflect.TypeOf((*proto.Message)(nil)).Elem()
|
||||
return elemTyp.Implements(modelType)
|
||||
}
|
||||
133
encoding/ssz/deep_equal_test.go
Normal file
133
encoding/ssz/deep_equal_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package ssz_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
|
||||
)
|
||||
|
||||
func TestDeepEqualBasicTypes(t *testing.T) {
|
||||
assert.Equal(t, true, ssz.DeepEqual(true, true))
|
||||
assert.Equal(t, false, ssz.DeepEqual(true, false))
|
||||
|
||||
assert.Equal(t, true, ssz.DeepEqual(byte(222), byte(222)))
|
||||
assert.Equal(t, false, ssz.DeepEqual(byte(222), byte(111)))
|
||||
|
||||
assert.Equal(t, true, ssz.DeepEqual(uint64(1234567890), uint64(1234567890)))
|
||||
assert.Equal(t, false, ssz.DeepEqual(uint64(1234567890), uint64(987653210)))
|
||||
|
||||
assert.Equal(t, true, ssz.DeepEqual("hello", "hello"))
|
||||
assert.Equal(t, false, ssz.DeepEqual("hello", "world"))
|
||||
|
||||
assert.Equal(t, true, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 3}))
|
||||
assert.Equal(t, false, ssz.DeepEqual([3]byte{1, 2, 3}, [3]byte{1, 2, 4}))
|
||||
|
||||
var nilSlice1, nilSlice2 []byte
|
||||
assert.Equal(t, true, ssz.DeepEqual(nilSlice1, nilSlice2))
|
||||
assert.Equal(t, true, ssz.DeepEqual(nilSlice1, []byte{}))
|
||||
assert.Equal(t, true, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 3}))
|
||||
assert.Equal(t, false, ssz.DeepEqual([]byte{1, 2, 3}, []byte{1, 2, 4}))
|
||||
}
|
||||
|
||||
func TestDeepEqualStructs(t *testing.T) {
|
||||
type Store struct {
|
||||
V1 uint64
|
||||
V2 []byte
|
||||
}
|
||||
store1 := Store{uint64(1234), nil}
|
||||
store2 := Store{uint64(1234), []byte{}}
|
||||
store3 := Store{uint64(4321), []byte{}}
|
||||
assert.Equal(t, true, ssz.DeepEqual(store1, store2))
|
||||
assert.Equal(t, false, ssz.DeepEqual(store1, store3))
|
||||
}
|
||||
|
||||
func TestDeepEqualStructs_Unexported(t *testing.T) {
|
||||
type Store struct {
|
||||
V1 uint64
|
||||
V2 []byte
|
||||
dontIgnoreMe string
|
||||
}
|
||||
store1 := Store{uint64(1234), nil, "hi there"}
|
||||
store2 := Store{uint64(1234), []byte{}, "hi there"}
|
||||
store3 := Store{uint64(4321), []byte{}, "wow"}
|
||||
store4 := Store{uint64(4321), []byte{}, "bow wow"}
|
||||
assert.Equal(t, true, ssz.DeepEqual(store1, store2))
|
||||
assert.Equal(t, false, ssz.DeepEqual(store1, store3))
|
||||
assert.Equal(t, false, ssz.DeepEqual(store3, store4))
|
||||
}
|
||||
|
||||
func TestDeepEqualProto(t *testing.T) {
|
||||
var fork1, fork2 *ethpb.Fork
|
||||
assert.Equal(t, true, ssz.DeepEqual(fork1, fork2))
|
||||
|
||||
fork1 = ðpb.Fork{
|
||||
PreviousVersion: []byte{123},
|
||||
CurrentVersion: []byte{124},
|
||||
Epoch: 1234567890,
|
||||
}
|
||||
fork2 = ðpb.Fork{
|
||||
PreviousVersion: []byte{123},
|
||||
CurrentVersion: []byte{125},
|
||||
Epoch: 1234567890,
|
||||
}
|
||||
assert.Equal(t, true, ssz.DeepEqual(fork1, fork1))
|
||||
assert.Equal(t, false, ssz.DeepEqual(fork1, fork2))
|
||||
|
||||
checkpoint1 := ðpb.Checkpoint{
|
||||
Epoch: 1234567890,
|
||||
Root: []byte{},
|
||||
}
|
||||
checkpoint2 := ðpb.Checkpoint{
|
||||
Epoch: 1234567890,
|
||||
Root: nil,
|
||||
}
|
||||
assert.Equal(t, true, ssz.DeepEqual(checkpoint1, checkpoint2))
|
||||
}
|
||||
|
||||
func Test_IsProto(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
item interface{}
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "uint64",
|
||||
item: 0,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
item: "foobar cheese",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "uint64 array",
|
||||
item: []uint64{1, 2, 3, 4, 5, 6},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Attestation",
|
||||
item: ðpb.Attestation{},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Array of attestations",
|
||||
item: []*ethpb.Attestation{},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Map of attestations",
|
||||
item: make(map[uint64]*ethpb.Attestation),
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ssz.IsProto(tt.item); got != tt.want {
|
||||
t.Errorf("isProtoSlice() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
51
encoding/ssz/hashers.go
Normal file
51
encoding/ssz/hashers.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package ssz
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// HashFn is the generic hash function signature.
|
||||
type HashFn func(input []byte) [32]byte
|
||||
|
||||
// Hasher describes an interface through which we can
|
||||
// perform hash operations on byte arrays,indices,etc.
|
||||
type Hasher interface {
|
||||
Hash(a []byte) [32]byte
|
||||
Combi(a [32]byte, b [32]byte) [32]byte
|
||||
MixIn(a [32]byte, i uint64) [32]byte
|
||||
}
|
||||
|
||||
// HasherFunc defines a structure to hold a hash function and can be used for multiple rounds of
|
||||
// hashing.
|
||||
type HasherFunc struct {
|
||||
b [64]byte
|
||||
hashFunc HashFn
|
||||
}
|
||||
|
||||
// NewHasherFunc is the constructor for the object
|
||||
// that fulfills the Hasher interface.
|
||||
func NewHasherFunc(h HashFn) *HasherFunc {
|
||||
return &HasherFunc{
|
||||
b: [64]byte{},
|
||||
hashFunc: h,
|
||||
}
|
||||
}
|
||||
|
||||
// Hash utilizes the provided hash function for
|
||||
// the object.
|
||||
func (h *HasherFunc) Hash(a []byte) [32]byte {
|
||||
return h.hashFunc(a)
|
||||
}
|
||||
|
||||
// Combi appends the two inputs and hashes them.
|
||||
func (h *HasherFunc) Combi(a, b [32]byte) [32]byte {
|
||||
copy(h.b[:32], a[:])
|
||||
copy(h.b[32:], b[:])
|
||||
return h.Hash(h.b[:])
|
||||
}
|
||||
|
||||
// MixIn works like Combi, but using an integer as the second input.
|
||||
func (h *HasherFunc) MixIn(a [32]byte, i uint64) [32]byte {
|
||||
copy(h.b[:32], a[:])
|
||||
copy(h.b[32:], make([]byte, 32))
|
||||
binary.LittleEndian.PutUint64(h.b[32:], i)
|
||||
return h.Hash(h.b[:])
|
||||
}
|
||||
35
encoding/ssz/hashers_test.go
Normal file
35
encoding/ssz/hashers_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package ssz_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
|
||||
)
|
||||
|
||||
func TestHash(t *testing.T) {
|
||||
byteSlice := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||
hasher := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
|
||||
expected := [32]byte{71, 228, 238, 127, 33, 31, 115, 38, 93, 209, 118, 88, 246, 226, 28, 19, 24, 189, 108, 129, 243, 117, 152, 226, 10, 39, 86, 41, 149, 66, 239, 207}
|
||||
result := hasher.Hash(byteSlice)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCombi(t *testing.T) {
|
||||
byteSlice1 := [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
|
||||
byteSlice2 := [32]byte{32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}
|
||||
hasher := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
|
||||
expected := [32]byte{203, 73, 0, 148, 142, 9, 145, 147, 186, 232, 143, 117, 95, 44, 38, 46, 102, 69, 101, 74, 50, 37, 87, 189, 40, 196, 203, 140, 19, 233, 161, 225}
|
||||
result := hasher.Combi(byteSlice1, byteSlice2)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestMixIn(t *testing.T) {
|
||||
byteSlice := [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
|
||||
intToAdd := uint64(33)
|
||||
hasher := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
|
||||
expected := [32]byte{170, 90, 0, 249, 34, 60, 140, 68, 77, 51, 218, 139, 54, 119, 179, 238, 80, 72, 13, 20, 212, 218, 124, 215, 68, 122, 214, 157, 178, 85, 225, 213}
|
||||
result := hasher.MixIn(byteSlice, intToAdd)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
128
encoding/ssz/helpers.go
Normal file
128
encoding/ssz/helpers.go
Normal file
@@ -0,0 +1,128 @@
|
||||
// Package ssz defines HashTreeRoot utility functions.
|
||||
package ssz
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/minio/sha256-simd"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prysmaticlabs/go-bitfield"
|
||||
)
|
||||
|
||||
const bytesPerChunk = 32
|
||||
|
||||
// BitlistRoot returns the mix in length of a bitwise Merkleized bitfield.
|
||||
func BitlistRoot(hasher HashFn, bfield bitfield.Bitfield, maxCapacity uint64) ([32]byte, error) {
|
||||
limit := (maxCapacity + 255) / 256
|
||||
if bfield == nil || bfield.Len() == 0 {
|
||||
length := make([]byte, 32)
|
||||
root, err := BitwiseMerkleize(hasher, [][]byte{}, 0, limit)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
return MixInLength(root, length), nil
|
||||
}
|
||||
chunks, err := Pack([][]byte{bfield.Bytes()})
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
buf := new(bytes.Buffer)
|
||||
if err := binary.Write(buf, binary.LittleEndian, bfield.Len()); err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
output := make([]byte, 32)
|
||||
copy(output, buf.Bytes())
|
||||
root, err := BitwiseMerkleize(hasher, chunks, uint64(len(chunks)), limit)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
return MixInLength(root, output), nil
|
||||
}
|
||||
|
||||
// BitwiseMerkleize - given ordered BYTES_PER_CHUNK-byte chunks, if necessary utilize
|
||||
// zero chunks so that the number of chunks is a power of two, Merkleize the chunks,
|
||||
// and return the root.
|
||||
// Note that merkleize on a single chunk is simply that chunk, i.e. the identity
|
||||
// when the number of chunks is one.
|
||||
func BitwiseMerkleize(hasher HashFn, chunks [][]byte, count, limit uint64) ([32]byte, error) {
|
||||
if count > limit {
|
||||
return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
|
||||
}
|
||||
hashFn := NewHasherFunc(hasher)
|
||||
leafIndexer := func(i uint64) []byte {
|
||||
return chunks[i]
|
||||
}
|
||||
return Merkleize(hashFn, count, limit, leafIndexer), nil
|
||||
}
|
||||
|
||||
// BitwiseMerkleizeArrays is used when a set of 32-byte root chunks are provided.
|
||||
func BitwiseMerkleizeArrays(hasher HashFn, chunks [][32]byte, count, limit uint64) ([32]byte, error) {
|
||||
if count > limit {
|
||||
return [32]byte{}, errors.New("merkleizing list that is too large, over limit")
|
||||
}
|
||||
hashFn := NewHasherFunc(hasher)
|
||||
leafIndexer := func(i uint64) []byte {
|
||||
return chunks[i][:]
|
||||
}
|
||||
return Merkleize(hashFn, count, limit, leafIndexer), nil
|
||||
}
|
||||
|
||||
// Pack a given byte array's final chunk with zeroes if needed.
|
||||
func Pack(serializedItems [][]byte) ([][]byte, error) {
|
||||
areAllEmpty := true
|
||||
for _, item := range serializedItems {
|
||||
if !bytes.Equal(item, []byte{}) {
|
||||
areAllEmpty = false
|
||||
break
|
||||
}
|
||||
}
|
||||
// If there are no items, we return an empty chunk.
|
||||
if len(serializedItems) == 0 || areAllEmpty {
|
||||
emptyChunk := make([]byte, bytesPerChunk)
|
||||
return [][]byte{emptyChunk}, nil
|
||||
} else if len(serializedItems[0]) == bytesPerChunk {
|
||||
// If each item has exactly BYTES_PER_CHUNK length, we return the list of serialized items.
|
||||
return serializedItems, nil
|
||||
}
|
||||
// We flatten the list in order to pack its items into byte chunks correctly.
|
||||
var orderedItems []byte
|
||||
for _, item := range serializedItems {
|
||||
orderedItems = append(orderedItems, item...)
|
||||
}
|
||||
numItems := len(orderedItems)
|
||||
var chunks [][]byte
|
||||
for i := 0; i < numItems; i += bytesPerChunk {
|
||||
j := i + bytesPerChunk
|
||||
// We create our upper bound index of the chunk, if it is greater than numItems,
|
||||
// we set it as numItems itself.
|
||||
if j > numItems {
|
||||
j = numItems
|
||||
}
|
||||
// We create chunks from the list of items based on the
|
||||
// indices determined above.
|
||||
chunks = append(chunks, orderedItems[i:j])
|
||||
}
|
||||
// Right-pad the last chunk with zero bytes if it does not
|
||||
// have length bytesPerChunk.
|
||||
lastChunk := chunks[len(chunks)-1]
|
||||
for len(lastChunk) < bytesPerChunk {
|
||||
lastChunk = append(lastChunk, 0)
|
||||
}
|
||||
chunks[len(chunks)-1] = lastChunk
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// MixInLength appends hash length to root
|
||||
func MixInLength(root [32]byte, length []byte) [32]byte {
|
||||
var hash [32]byte
|
||||
h := sha256.New()
|
||||
h.Write(root[:])
|
||||
h.Write(length)
|
||||
// The hash interface never returns an error, for that reason
|
||||
// we are not handling the error below. For reference, it is
|
||||
// stated here https://golang.org/pkg/hash/#Hash
|
||||
// #nosec G104
|
||||
h.Sum(hash[:0])
|
||||
return hash
|
||||
}
|
||||
103
encoding/ssz/helpers_test.go
Normal file
103
encoding/ssz/helpers_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package ssz_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/go-bitfield"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
|
||||
"github.com/prysmaticlabs/prysm/shared/testutil/require"
|
||||
)
|
||||
|
||||
const merkleizingListLimitError = "merkleizing list that is too large, over limit"
|
||||
|
||||
func TestBitlistRoot(t *testing.T) {
|
||||
hasher := hash.CustomSHA256Hasher()
|
||||
capacity := uint64(10)
|
||||
bfield := bitfield.NewBitlist(capacity)
|
||||
expected := [32]byte{176, 76, 194, 203, 142, 166, 117, 79, 148, 194, 231, 64, 60, 245, 142, 32, 201, 2, 58, 152, 53, 12, 132, 40, 41, 102, 224, 189, 103, 41, 211, 202}
|
||||
|
||||
result, err := ssz.BitlistRoot(hasher, bfield, capacity)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestBitwiseMerkleize(t *testing.T) {
|
||||
hasher := hash.CustomSHA256Hasher()
|
||||
chunks := [][]byte{
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
{11, 12, 13, 14, 15, 16, 17, 18, 19, 20},
|
||||
}
|
||||
count := uint64(2)
|
||||
limit := uint64(2)
|
||||
expected := [32]byte{194, 32, 213, 52, 220, 127, 18, 240, 43, 151, 19, 79, 188, 175, 142, 177, 208, 46, 96, 20, 18, 231, 208, 29, 120, 102, 122, 17, 46, 31, 155, 30}
|
||||
|
||||
result, err := ssz.BitwiseMerkleize(hasher, chunks, count, limit)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestBitwiseMerkleizeOverLimit(t *testing.T) {
|
||||
hasher := hash.CustomSHA256Hasher()
|
||||
chunks := [][]byte{
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
|
||||
{11, 12, 13, 14, 15, 16, 17, 18, 19, 20},
|
||||
}
|
||||
count := uint64(2)
|
||||
limit := uint64(1)
|
||||
|
||||
_, err := ssz.BitwiseMerkleize(hasher, chunks, count, limit)
|
||||
assert.ErrorContains(t, merkleizingListLimitError, err)
|
||||
}
|
||||
|
||||
func TestBitwiseMerkleizeArrays(t *testing.T) {
|
||||
hasher := hash.CustomSHA256Hasher()
|
||||
chunks := [][32]byte{
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
|
||||
{33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 62, 62, 63, 64},
|
||||
}
|
||||
count := uint64(2)
|
||||
limit := uint64(2)
|
||||
expected := [32]byte{138, 81, 210, 194, 151, 231, 249, 241, 64, 118, 209, 58, 145, 109, 225, 89, 118, 110, 159, 220, 193, 183, 203, 124, 166, 24, 65, 26, 160, 215, 233, 219}
|
||||
|
||||
result, err := ssz.BitwiseMerkleizeArrays(hasher, chunks, count, limit)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestBitwiseMerkleizeArraysOverLimit(t *testing.T) {
|
||||
hasher := hash.CustomSHA256Hasher()
|
||||
chunks := [][32]byte{
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32},
|
||||
{33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 62, 62, 63, 64},
|
||||
}
|
||||
count := uint64(2)
|
||||
limit := uint64(1)
|
||||
|
||||
_, err := ssz.BitwiseMerkleizeArrays(hasher, chunks, count, limit)
|
||||
assert.ErrorContains(t, merkleizingListLimitError, err)
|
||||
}
|
||||
|
||||
func TestPack(t *testing.T) {
|
||||
byteSlice2D := [][]byte{
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9},
|
||||
{1, 1, 2, 3, 5, 8, 13, 21, 34},
|
||||
}
|
||||
expected := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2, 3, 5, 8, 13, 21, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
|
||||
result, err := ssz.Pack(byteSlice2D)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(expected), len(result[0]))
|
||||
for i, v := range expected {
|
||||
assert.DeepEqual(t, v, result[0][i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMixInLength(t *testing.T) {
|
||||
byteSlice := [32]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32}
|
||||
length := []byte{1, 2, 3}
|
||||
expected := [32]byte{105, 60, 167, 169, 197, 220, 122, 99, 59, 14, 250, 12, 251, 62, 135, 239, 29, 68, 140, 1, 6, 36, 207, 44, 64, 221, 76, 230, 237, 218, 150, 88}
|
||||
result := ssz.MixInLength(byteSlice, length)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
92
encoding/ssz/htrutils.go
Normal file
92
encoding/ssz/htrutils.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package ssz
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/shared/bytesutil"
|
||||
"github.com/prysmaticlabs/prysm/shared/params"
|
||||
)
|
||||
|
||||
// Uint64Root computes the HashTreeRoot Merkleization of
|
||||
// a simple uint64 value according to the Ethereum
|
||||
// Simple Serialize specification.
|
||||
func Uint64Root(val uint64) [32]byte {
|
||||
buf := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(buf, val)
|
||||
root := bytesutil.ToBytes32(buf)
|
||||
return root
|
||||
}
|
||||
|
||||
// ForkRoot computes the HashTreeRoot Merkleization of
|
||||
// a Fork struct value according to the Ethereum
|
||||
// Simple Serialize specification.
|
||||
func ForkRoot(fork *ethpb.Fork) ([32]byte, error) {
|
||||
fieldRoots := make([][]byte, 3)
|
||||
if fork != nil {
|
||||
prevRoot := bytesutil.ToBytes32(fork.PreviousVersion)
|
||||
fieldRoots[0] = prevRoot[:]
|
||||
currRoot := bytesutil.ToBytes32(fork.CurrentVersion)
|
||||
fieldRoots[1] = currRoot[:]
|
||||
forkEpochBuf := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(forkEpochBuf, uint64(fork.Epoch))
|
||||
epochRoot := bytesutil.ToBytes32(forkEpochBuf)
|
||||
fieldRoots[2] = epochRoot[:]
|
||||
}
|
||||
return BitwiseMerkleize(hash.CustomSHA256Hasher(), fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots)))
|
||||
}
|
||||
|
||||
// CheckpointRoot computes the HashTreeRoot Merkleization of
|
||||
// a InitWithReset struct value according to the Ethereum
|
||||
// Simple Serialize specification.
|
||||
func CheckpointRoot(hasher HashFn, checkpoint *ethpb.Checkpoint) ([32]byte, error) {
|
||||
fieldRoots := make([][]byte, 2)
|
||||
if checkpoint != nil {
|
||||
epochBuf := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(epochBuf, uint64(checkpoint.Epoch))
|
||||
epochRoot := bytesutil.ToBytes32(epochBuf)
|
||||
fieldRoots[0] = epochRoot[:]
|
||||
ckpRoot := bytesutil.ToBytes32(checkpoint.Root)
|
||||
fieldRoots[1] = ckpRoot[:]
|
||||
}
|
||||
return BitwiseMerkleize(hasher, fieldRoots, uint64(len(fieldRoots)), uint64(len(fieldRoots)))
|
||||
}
|
||||
|
||||
// ByteArrayRootWithLimit computes the HashTreeRoot Merkleization of
|
||||
// a list of [32]byte roots according to the Ethereum Simple Serialize
|
||||
// specification.
|
||||
func ByteArrayRootWithLimit(roots [][]byte, limit uint64) ([32]byte, error) {
|
||||
result, err := BitwiseMerkleize(hash.CustomSHA256Hasher(), roots, uint64(len(roots)), limit)
|
||||
if err != nil {
|
||||
return [32]byte{}, errors.Wrap(err, "could not compute byte array merkleization")
|
||||
}
|
||||
buf := new(bytes.Buffer)
|
||||
if err := binary.Write(buf, binary.LittleEndian, uint64(len(roots))); err != nil {
|
||||
return [32]byte{}, errors.Wrap(err, "could not marshal byte array length")
|
||||
}
|
||||
// We need to mix in the length of the slice.
|
||||
output := make([]byte, 32)
|
||||
copy(output, buf.Bytes())
|
||||
mixedLen := MixInLength(result, output)
|
||||
return mixedLen, nil
|
||||
}
|
||||
|
||||
// SlashingsRoot computes the HashTreeRoot Merkleization of
|
||||
// a list of uint64 slashing values according to the Ethereum
|
||||
// Simple Serialize specification.
|
||||
func SlashingsRoot(slashings []uint64) ([32]byte, error) {
|
||||
slashingMarshaling := make([][]byte, params.BeaconConfig().EpochsPerSlashingsVector)
|
||||
for i := 0; i < len(slashings) && i < len(slashingMarshaling); i++ {
|
||||
slashBuf := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(slashBuf, slashings[i])
|
||||
slashingMarshaling[i] = slashBuf
|
||||
}
|
||||
slashingChunks, err := Pack(slashingMarshaling)
|
||||
if err != nil {
|
||||
return [32]byte{}, errors.Wrap(err, "could not pack slashings into chunks")
|
||||
}
|
||||
return BitwiseMerkleize(hash.CustomSHA256Hasher(), slashingChunks, uint64(len(slashingChunks)), uint64(len(slashingChunks)))
|
||||
}
|
||||
63
encoding/ssz/htrutils_test.go
Normal file
63
encoding/ssz/htrutils_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package ssz_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
|
||||
"github.com/prysmaticlabs/prysm/shared/testutil/require"
|
||||
)
|
||||
|
||||
func TestUint64Root(t *testing.T) {
|
||||
uintVal := uint64(1234567890)
|
||||
expected := [32]byte{210, 2, 150, 73, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
|
||||
result := ssz.Uint64Root(uintVal)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestForkRoot(t *testing.T) {
|
||||
testFork := ethpb.Fork{
|
||||
PreviousVersion: []byte{123},
|
||||
CurrentVersion: []byte{124},
|
||||
Epoch: 1234567890,
|
||||
}
|
||||
expected := [32]byte{19, 46, 77, 103, 92, 175, 247, 33, 100, 64, 17, 111, 199, 145, 69, 38, 217, 112, 6, 16, 149, 201, 225, 144, 192, 228, 197, 172, 157, 78, 114, 140}
|
||||
|
||||
result, err := ssz.ForkRoot(&testFork)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCheckPointRoot(t *testing.T) {
|
||||
testHasher := hash.CustomSHA256Hasher()
|
||||
testCheckpoint := ethpb.Checkpoint{
|
||||
Epoch: 1234567890,
|
||||
Root: []byte{222},
|
||||
}
|
||||
expected := [32]byte{228, 65, 39, 109, 183, 249, 167, 232, 125, 239, 25, 155, 207, 4, 84, 174, 176, 229, 175, 224, 62, 33, 215, 254, 170, 220, 132, 65, 246, 128, 68, 194}
|
||||
|
||||
result, err := ssz.CheckpointRoot(testHasher, &testCheckpoint)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestByteArrayRootWithLimit(t *testing.T) {
|
||||
testHistoricalRoots := [][]byte{{123}, {234}}
|
||||
expected := [32]byte{70, 204, 150, 196, 89, 138, 190, 205, 65, 207, 120, 166, 179, 247, 147, 20, 29, 133, 117, 116, 151, 234, 129, 32, 22, 15, 79, 178, 98, 73, 132, 152}
|
||||
|
||||
result, err := ssz.ByteArrayRootWithLimit(testHistoricalRoots, 16777216)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestSlashingsRoot(t *testing.T) {
|
||||
testSlashingsRoot := []uint64{123, 234}
|
||||
expected := [32]byte{123, 0, 0, 0, 0, 0, 0, 0, 234, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
|
||||
result, err := ssz.SlashingsRoot(testSlashingsRoot)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
198
encoding/ssz/merkleize.go
Normal file
198
encoding/ssz/merkleize.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package ssz
|
||||
|
||||
import (
|
||||
"github.com/prysmaticlabs/prysm/container/trie"
|
||||
)
|
||||
|
||||
// Merkleize.go is mostly a directly copy of the same filename from
|
||||
// https://github.com/protolambda/zssz/blob/master/merkle/merkleize.go.
|
||||
// The reason the method is copied instead of imported is due to us using a
|
||||
// a custom hasher interface for a reduced memory footprint when using
|
||||
// 'Merkleize'.
|
||||
|
||||
const (
|
||||
mask0 = ^uint64((1 << (1 << iota)) - 1)
|
||||
mask1
|
||||
mask2
|
||||
mask3
|
||||
mask4
|
||||
mask5
|
||||
)
|
||||
|
||||
const (
|
||||
bit0 = uint8(1 << iota)
|
||||
bit1
|
||||
bit2
|
||||
bit3
|
||||
bit4
|
||||
bit5
|
||||
)
|
||||
|
||||
// Depth retrieves the appropriate depth for the provided trie size.
|
||||
func Depth(v uint64) (out uint8) {
|
||||
// bitmagic: binary search through a uint32, offset down by 1 to not round powers of 2 up.
|
||||
// Then adding 1 to it to not get the index of the first bit, but the length of the bits (depth of tree)
|
||||
// Zero is a special case, it has a 0 depth.
|
||||
// Example:
|
||||
// (in out): (0 0), (1 1), (2 1), (3 2), (4 2), (5 3), (6 3), (7 3), (8 3), (9 4)
|
||||
if v == 0 {
|
||||
return 0
|
||||
}
|
||||
v--
|
||||
if v&mask5 != 0 {
|
||||
v >>= bit5
|
||||
out |= bit5
|
||||
}
|
||||
if v&mask4 != 0 {
|
||||
v >>= bit4
|
||||
out |= bit4
|
||||
}
|
||||
if v&mask3 != 0 {
|
||||
v >>= bit3
|
||||
out |= bit3
|
||||
}
|
||||
if v&mask2 != 0 {
|
||||
v >>= bit2
|
||||
out |= bit2
|
||||
}
|
||||
if v&mask1 != 0 {
|
||||
v >>= bit1
|
||||
out |= bit1
|
||||
}
|
||||
if v&mask0 != 0 {
|
||||
out |= bit0
|
||||
}
|
||||
out++
|
||||
return
|
||||
}
|
||||
|
||||
// Merkleize with log(N) space allocation
|
||||
func Merkleize(hasher Hasher, count, limit uint64, leaf func(i uint64) []byte) (out [32]byte) {
|
||||
if count > limit {
|
||||
panic("merkleizing list that is too large, over limit")
|
||||
}
|
||||
if limit == 0 {
|
||||
return
|
||||
}
|
||||
if limit == 1 {
|
||||
if count == 1 {
|
||||
copy(out[:], leaf(0))
|
||||
}
|
||||
return
|
||||
}
|
||||
depth := Depth(count)
|
||||
limitDepth := Depth(limit)
|
||||
tmp := make([][32]byte, limitDepth+1)
|
||||
|
||||
j := uint8(0)
|
||||
hArr := [32]byte{}
|
||||
h := hArr[:]
|
||||
|
||||
merge := func(i uint64) {
|
||||
// merge back up from bottom to top, as far as we can
|
||||
for j = 0; ; j++ {
|
||||
// stop merging when we are in the left side of the next combi
|
||||
if i&(uint64(1)<<j) == 0 {
|
||||
// if we are at the count, we want to merge in zero-hashes for padding
|
||||
if i == count && j < depth {
|
||||
v := hasher.Combi(hArr, trie.ZeroHashes[j])
|
||||
copy(h, v[:])
|
||||
} else {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// keep merging up if we are the right side
|
||||
v := hasher.Combi(tmp[j], hArr)
|
||||
copy(h, v[:])
|
||||
}
|
||||
}
|
||||
// store the merge result (may be no merge, i.e. bottom leaf node)
|
||||
copy(tmp[j][:], h)
|
||||
}
|
||||
|
||||
// merge in leaf by leaf.
|
||||
for i := uint64(0); i < count; i++ {
|
||||
copy(h, leaf(i))
|
||||
merge(i)
|
||||
}
|
||||
|
||||
// complement with 0 if empty, or if not the right power of 2
|
||||
if (uint64(1) << depth) != count {
|
||||
copy(h, trie.ZeroHashes[0][:])
|
||||
merge(count)
|
||||
}
|
||||
|
||||
// the next power of two may be smaller than the ultimate virtual size,
|
||||
// complement with zero-hashes at each depth.
|
||||
for j := depth; j < limitDepth; j++ {
|
||||
tmp[j+1] = hasher.Combi(tmp[j], trie.ZeroHashes[j])
|
||||
}
|
||||
|
||||
return tmp[limitDepth]
|
||||
}
|
||||
|
||||
// ConstructProof builds a merkle-branch of the given depth, at the given index (at that depth),
|
||||
// for a list of leafs of a balanced binary tree.
|
||||
func ConstructProof(hasher Hasher, count, limit uint64, leaf func(i uint64) []byte, index uint64) (branch [][32]byte) {
|
||||
if count > limit {
|
||||
panic("merkleizing list that is too large, over limit")
|
||||
}
|
||||
if index >= limit {
|
||||
panic("index out of range, over limit")
|
||||
}
|
||||
if limit <= 1 {
|
||||
return
|
||||
}
|
||||
depth := Depth(count)
|
||||
limitDepth := Depth(limit)
|
||||
branch = append(branch, trie.ZeroHashes[:limitDepth]...)
|
||||
|
||||
tmp := make([][32]byte, limitDepth+1)
|
||||
|
||||
j := uint8(0)
|
||||
hArr := [32]byte{}
|
||||
h := hArr[:]
|
||||
|
||||
merge := func(i uint64) {
|
||||
// merge back up from bottom to top, as far as we can
|
||||
for j = 0; ; j++ {
|
||||
// if i is a sibling of index at the given depth,
|
||||
// and i is the last index of the subtree to that depth,
|
||||
// then put h into the branch
|
||||
if (i>>j)^1 == (index>>j) && (((1<<j)-1)&i) == ((1<<j)-1) {
|
||||
// insert sibling into the proof
|
||||
branch[j] = hArr
|
||||
}
|
||||
// stop merging when we are in the left side of the next combi
|
||||
if i&(uint64(1)<<j) == 0 {
|
||||
// if we are at the count, we want to merge in zero-hashes for padding
|
||||
if i == count && j < depth {
|
||||
v := hasher.Combi(hArr, trie.ZeroHashes[j])
|
||||
copy(h, v[:])
|
||||
} else {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// keep merging up if we are the right side
|
||||
v := hasher.Combi(tmp[j], hArr)
|
||||
copy(h, v[:])
|
||||
}
|
||||
}
|
||||
// store the merge result (may be no merge, i.e. bottom leaf node)
|
||||
copy(tmp[j][:], h)
|
||||
}
|
||||
|
||||
// merge in leaf by leaf.
|
||||
for i := uint64(0); i < count; i++ {
|
||||
copy(h, leaf(i))
|
||||
merge(i)
|
||||
}
|
||||
|
||||
// complement with 0 if empty, or if not the right power of 2
|
||||
if (uint64(1) << depth) != count {
|
||||
copy(h, trie.ZeroHashes[0][:])
|
||||
merge(count)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
114
encoding/ssz/merkleize_test.go
Normal file
114
encoding/ssz/merkleize_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package ssz_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/crypto/hash"
|
||||
"github.com/prysmaticlabs/prysm/encoding/ssz"
|
||||
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
|
||||
)
|
||||
|
||||
func TestGetDepth(t *testing.T) {
|
||||
trieSize := uint64(896745231)
|
||||
expected := uint8(30)
|
||||
|
||||
result := ssz.Depth(trieSize)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestMerkleizeCountGreaterThanLimit(t *testing.T) {
|
||||
hashFn := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
|
||||
count := uint64(2)
|
||||
limit := uint64(1)
|
||||
chunks := [][]byte{{}}
|
||||
leafIndexer := func(i uint64) []byte {
|
||||
return chunks[i]
|
||||
}
|
||||
// Error if no panic
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("The code did not panic.")
|
||||
}
|
||||
}()
|
||||
ssz.Merkleize(hashFn, count, limit, leafIndexer)
|
||||
}
|
||||
|
||||
func TestMerkleizeLimitAndCountAreZero(t *testing.T) {
|
||||
hashFn := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
|
||||
count := uint64(0)
|
||||
limit := uint64(0)
|
||||
chunks := [][]byte{{}}
|
||||
leafIndexer := func(i uint64) []byte {
|
||||
return chunks[i]
|
||||
}
|
||||
expected := [32]byte{}
|
||||
result := ssz.Merkleize(hashFn, count, limit, leafIndexer)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestMerkleizeNormalPath(t *testing.T) {
|
||||
hashFn := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
|
||||
count := uint64(2)
|
||||
limit := uint64(3)
|
||||
chunks := [][]byte{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}
|
||||
leafIndexer := func(i uint64) []byte {
|
||||
return chunks[i]
|
||||
}
|
||||
expected := [32]byte{95, 27, 253, 237, 215, 58, 147, 198, 175, 194, 180, 231, 154, 130, 205, 68, 146, 112, 225, 86, 6, 103, 186, 82, 7, 142, 33, 189, 174, 56, 199, 173}
|
||||
result := ssz.Merkleize(hashFn, count, limit, leafIndexer)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestConstructProofCountGreaterThanLimit(t *testing.T) {
|
||||
hashFn := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
|
||||
count := uint64(2)
|
||||
limit := uint64(1)
|
||||
chunks := [][]byte{{}}
|
||||
leafIndexer := func(i uint64) []byte {
|
||||
return chunks[i]
|
||||
}
|
||||
index := uint64(0)
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("The code did not panic.")
|
||||
}
|
||||
}()
|
||||
ssz.ConstructProof(hashFn, count, limit, leafIndexer, index)
|
||||
}
|
||||
|
||||
func TestConstructProofIndexGreaterThanEqualToLimit(t *testing.T) {
|
||||
hashFn := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
|
||||
count := uint64(1)
|
||||
limit := uint64(1)
|
||||
chunks := [][]byte{{}}
|
||||
leafIndexer := func(i uint64) []byte {
|
||||
return chunks[i]
|
||||
}
|
||||
index := uint64(1)
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("The code did not panic.")
|
||||
}
|
||||
}()
|
||||
ssz.ConstructProof(hashFn, count, limit, leafIndexer, index)
|
||||
}
|
||||
|
||||
func TestConstructProofNormalPath(t *testing.T) {
|
||||
hashFn := ssz.NewHasherFunc(hash.CustomSHA256Hasher())
|
||||
count := uint64(2)
|
||||
limit := uint64(3)
|
||||
chunks := [][]byte{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}
|
||||
leafIndexer := func(i uint64) []byte {
|
||||
return chunks[i]
|
||||
}
|
||||
index := uint64(1)
|
||||
expected := [][32]byte{
|
||||
{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
{245, 165, 253, 66, 209, 106, 32, 48, 39, 152, 239, 110, 211, 9, 151, 155, 67, 0, 61, 35, 32, 217, 240, 232, 234, 152, 49, 169, 39, 89, 251, 75},
|
||||
}
|
||||
result := ssz.ConstructProof(hashFn, count, limit, leafIndexer, index)
|
||||
assert.Equal(t, len(expected), len(result))
|
||||
for i, v := range expected {
|
||||
assert.DeepEqual(t, result[i], v)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user